diff --git a/tests/nexus_api_test.go b/tests/nexus_api_test.go index d47d66599e2..90b3ae9bfa5 100644 --- a/tests/nexus_api_test.go +++ b/tests/nexus_api_test.go @@ -11,7 +11,6 @@ import ( "github.com/nexus-rpc/sdk-go/nexus" "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" commonpb "go.temporal.io/api/common/v1" enumspb "go.temporal.io/api/enums/v1" nexuspb "go.temporal.io/api/nexus/v1" @@ -25,6 +24,7 @@ import ( "go.temporal.io/server/common/metrics/metricstest" commonnexus "go.temporal.io/server/common/nexus" "go.temporal.io/server/common/nexus/nexusrpc" + "go.temporal.io/server/common/testing/parallelsuite" "go.temporal.io/server/components/nexusoperations" "go.temporal.io/server/service/frontend/configs" "go.temporal.io/server/tests/testcore" @@ -50,22 +50,18 @@ func newHeaderCaptureCaller() (func(*http.Request) (*http.Response, error), *hea var op = nexus.NewOperationReference[string, string]("my-operation") type NexusApiTestSuite struct { - NexusTestBaseSuite + parallelsuite.Suite[*NexusApiTestSuite] } func TestNexusApiTestSuiteWithLegacyErrorPaths(t *testing.T) { - t.Parallel() - suite.Run(t, new(NexusApiTestSuite)) + parallelsuite.Run(t, &NexusApiTestSuite{}, false) // useTemporalFailures = false } func TestNexusApiTestSuiteWithTemporalFailures(t *testing.T) { - t.Parallel() - s := new(NexusApiTestSuite) - s.useTemporalFailures = true - suite.Run(t, s) + parallelsuite.Run(t, &NexusApiTestSuite{}, true) // useTemporalFailures = true } -func (s *NexusApiTestSuite) TestNexusStartOperation_Outcomes() { +func (s *NexusApiTestSuite) TestNexusStartOperation_Outcomes(useTemporalFailures bool) { callerLink := &commonpb.Link_WorkflowEvent{ Namespace: "caller-ns", WorkflowId: "caller-wf-id", @@ -94,67 +90,67 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_Outcomes() { asyncSuccessEndpoint := testcore.RandomizeStr("test-endpoint") operationErrorOutcome := "operation_error" - if s.useTemporalFailures { + if useTemporalFailures { operationErrorOutcome = "failure" } type testcase struct { name string outcome string - endpoint *nexuspb.Endpoint + endpointName string timeout time.Duration handler nexusTaskHandler - assertion func(*testing.T, *nexusrpc.ClientStartOperationResponse[string], error, http.Header) + assertion func(*NexusApiTestSuite, *nexusrpc.ClientStartOperationResponse[string], error, http.Header) onlyByEndpoint bool } testCases := []testcase{ { - name: "sync_success", - outcome: "sync_success", - endpoint: s.createNexusEndpoint(testcore.RandomizeStr("test-endpoint"), testcore.RandomizeStr("task-queue")), - handler: nexusEchoHandler, - assertion: func(t *testing.T, res *nexusrpc.ClientStartOperationResponse[string], err error, headers http.Header) { - require.NoError(t, err) - require.Equal(t, "input", res.Successful) + name: "sync_success", + outcome: "sync_success", + endpointName: testcore.RandomizeStr("test-endpoint"), + handler: nexusEchoHandler, + assertion: func(s *NexusApiTestSuite, res *nexusrpc.ClientStartOperationResponse[string], err error, headers http.Header) { + s.NoError(err) + s.Equal("input", res.Successful) }, }, { name: "async_success", outcome: "async_success", onlyByEndpoint: true, - endpoint: s.createNexusEndpoint(asyncSuccessEndpoint, testcore.RandomizeStr("task-queue")), - handler: func(res *workflowservice.PollNexusTaskQueueResponse) (*nexusTaskResponse, error) { + endpointName: asyncSuccessEndpoint, + handler: func(t *testing.T, res *workflowservice.PollNexusTaskQueueResponse) (*nexusTaskResponse, error) { // Choose an arbitrary test case to assert that all of the input is delivered to the // poll response. - s.Equal(asyncSuccessEndpoint, res.Request.Endpoint) + require.Equal(t, asyncSuccessEndpoint, res.Request.Endpoint) start := res.Request.Variant.(*nexuspb.Request_StartOperation).StartOperation - s.Equal(op.Name(), start.Operation) - s.Equal("http://localhost/callback", start.Callback) - s.Equal("request-id", start.RequestId) - s.Equal("value", res.Request.Header["key"]) - s.NotContains(res.Request.Header, "temporal-nexus-failure-support") - s.Len(start.GetLinks(), 1) - s.Equal(callerNexusLink.URL.String(), start.Links[0].GetUrl()) - s.Equal(callerNexusLink.Type, start.Links[0].Type) + require.Equal(t, op.Name(), start.Operation) + require.Equal(t, "http://localhost/callback", start.Callback) + require.Equal(t, "request-id", start.RequestId) + require.Equal(t, "value", res.Request.Header["key"]) + require.NotContains(t, res.Request.Header, "temporal-nexus-failure-support") + require.Len(t, start.GetLinks(), 1) + require.Equal(t, callerNexusLink.URL.String(), start.Links[0].GetUrl()) + require.Equal(t, callerNexusLink.Type, start.Links[0].Type) return &nexusTaskResponse{ StartResult: &nexus.HandlerStartOperationResultAsync{OperationToken: "test-token"}, Links: []nexus.Link{handlerNexusLink}, }, nil }, - assertion: func(t *testing.T, res *nexusrpc.ClientStartOperationResponse[string], err error, headers http.Header) { - require.NoError(t, err) - require.Equal(t, "test-token", res.Pending.Token) - require.Len(t, res.Links, 1) - require.Equal(t, handlerNexusLink.URL.String(), res.Links[0].URL.String()) - require.Equal(t, handlerNexusLink.Type, res.Links[0].Type) + assertion: func(s *NexusApiTestSuite, res *nexusrpc.ClientStartOperationResponse[string], err error, headers http.Header) { + s.NoError(err) + s.Equal("test-token", res.Pending.Token) + s.Len(res.Links, 1) + s.Equal(handlerNexusLink.URL.String(), res.Links[0].URL.String()) + s.Equal(handlerNexusLink.Type, res.Links[0].Type) }, }, { - name: "operation_error", - outcome: operationErrorOutcome, - endpoint: s.createNexusEndpoint(testcore.RandomizeStr("test-endpoint"), testcore.RandomizeStr("task-queue")), - handler: func(_ *workflowservice.PollNexusTaskQueueResponse) (*nexusTaskResponse, error) { + name: "operation_error", + outcome: operationErrorOutcome, + endpointName: testcore.RandomizeStr("test-endpoint"), + handler: func(_ *testing.T, _ *workflowservice.PollNexusTaskQueueResponse) (*nexusTaskResponse, error) { return nil, &nexus.OperationError{ State: nexus.OperationStateFailed, Cause: &nexus.FailureError{ @@ -166,43 +162,43 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_Outcomes() { }, } }, - assertion: func(t *testing.T, res *nexusrpc.ClientStartOperationResponse[string], err error, headers http.Header) { + assertion: func(s *NexusApiTestSuite, res *nexusrpc.ClientStartOperationResponse[string], err error, headers http.Header) { var operationError *nexus.OperationError - require.ErrorAs(t, err, &operationError) - require.Equal(t, nexus.OperationStateFailed, operationError.State) - if s.useTemporalFailures { + s.ErrorAs(err, &operationError) + s.Equal(nexus.OperationStateFailed, operationError.State) + if useTemporalFailures { // Through the Temporal failure round-trip, the cause chain has an extra wrapper // for the OperationError's ApplicationFailureInfo. var failureErr *nexus.FailureError - require.ErrorAs(t, operationError.Cause, &failureErr) + s.ErrorAs(operationError.Cause, &failureErr) var innerErr *nexus.FailureError - require.ErrorAs(t, failureErr.Cause, &innerErr) + s.ErrorAs(failureErr.Cause, &innerErr) tFailure, err := commonnexus.NexusFailureToTemporalFailure(innerErr.Failure) - require.NoError(t, err) + s.NoError(err) convErr := temporal.GetDefaultFailureConverter().FailureToError(tFailure) var appErr *temporal.ApplicationError - require.ErrorAs(t, convErr, &appErr) - require.Equal(t, "deliberate test failure", appErr.Message()) + s.ErrorAs(convErr, &appErr) + s.Equal("deliberate test failure", appErr.Message()) var details nexus.Failure - require.NoError(t, appErr.Details(&details)) - require.Equal(t, "v", details.Metadata["k"]) + s.NoError(appErr.Details(&details)) + s.Equal("v", details.Metadata["k"]) } else { - require.Equal(t, "deliberate test failure", operationError.Cause.Error()) + s.Equal("deliberate test failure", operationError.Cause.Error()) var failureErr *nexus.FailureError - require.ErrorAs(t, operationError.Cause, &failureErr) - require.Equal(t, map[string]string{"k": "v"}, failureErr.Failure.Metadata) + s.ErrorAs(operationError.Cause, &failureErr) + s.Equal(map[string]string{"k": "v"}, failureErr.Failure.Metadata) var details string err = json.Unmarshal(failureErr.Failure.Details, &details) - require.NoError(t, err) - require.Equal(t, "details", details) + s.NoError(err) + s.Equal("details", details) } }, }, { - name: "handler_error", - outcome: "handler_error:INTERNAL", - endpoint: s.createNexusEndpoint(testcore.RandomizeStr("test-endpoint"), testcore.RandomizeStr("task-queue")), - handler: func(_ *workflowservice.PollNexusTaskQueueResponse) (*nexusTaskResponse, error) { + name: "handler_error", + outcome: "handler_error:INTERNAL", + endpointName: testcore.RandomizeStr("test-endpoint"), + handler: func(_ *testing.T, _ *workflowservice.PollNexusTaskQueueResponse) (*nexusTaskResponse, error) { return nil, &nexus.HandlerError{ Type: nexus.HandlerErrorTypeInternal, Cause: &nexus.FailureError{ @@ -210,22 +206,22 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_Outcomes() { }, } }, - assertion: func(t *testing.T, res *nexusrpc.ClientStartOperationResponse[string], err error, headers http.Header) { + assertion: func(s *NexusApiTestSuite, res *nexusrpc.ClientStartOperationResponse[string], err error, headers http.Header) { var handlerErr *nexus.HandlerError - require.ErrorAs(t, err, &handlerErr) - require.Equal(t, nexus.HandlerErrorTypeInternal, handlerErr.Type) - require.Equal(t, nexus.HandlerErrorRetryBehaviorUnspecified, handlerErr.RetryBehavior) - require.Equal(t, "worker", headers.Get("Temporal-Nexus-Failure-Source")) - require.Empty(t, handlerErr.Message) - require.Error(t, handlerErr.Cause) - require.Equal(t, "deliberate internal failure", handlerErr.Cause.Error()) + s.ErrorAs(err, &handlerErr) + s.Equal(nexus.HandlerErrorTypeInternal, handlerErr.Type) + s.Equal(nexus.HandlerErrorRetryBehaviorUnspecified, handlerErr.RetryBehavior) + s.Equal("worker", headers.Get("Temporal-Nexus-Failure-Source")) + s.Empty(handlerErr.Message) + s.Error(handlerErr.Cause) + s.Equal("deliberate internal failure", handlerErr.Cause.Error()) }, }, { - name: "handler_error_non_retryable", - outcome: "handler_error:INTERNAL", - endpoint: s.createNexusEndpoint(testcore.RandomizeStr("test-endpoint"), testcore.RandomizeStr("task-queue")), - handler: func(_ *workflowservice.PollNexusTaskQueueResponse) (*nexusTaskResponse, error) { + name: "handler_error_non_retryable", + outcome: "handler_error:INTERNAL", + endpointName: testcore.RandomizeStr("test-endpoint"), + handler: func(_ *testing.T, _ *workflowservice.PollNexusTaskQueueResponse) (*nexusTaskResponse, error) { return nil, &nexus.HandlerError{ Type: nexus.HandlerErrorTypeInternal, RetryBehavior: nexus.HandlerErrorRetryBehaviorNonRetryable, @@ -234,45 +230,53 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_Outcomes() { }, } }, - assertion: func(t *testing.T, res *nexusrpc.ClientStartOperationResponse[string], err error, headers http.Header) { + assertion: func(s *NexusApiTestSuite, res *nexusrpc.ClientStartOperationResponse[string], err error, headers http.Header) { var handlerErr *nexus.HandlerError - require.ErrorAs(t, err, &handlerErr) - require.Equal(t, nexus.HandlerErrorTypeInternal, handlerErr.Type) - require.Equal(t, nexus.HandlerErrorRetryBehaviorNonRetryable, handlerErr.RetryBehavior) - require.Equal(t, "worker", headers.Get("Temporal-Nexus-Failure-Source")) - require.Empty(t, handlerErr.Message) - require.Error(t, handlerErr.Cause) - require.Equal(t, "deliberate internal failure", handlerErr.Cause.Error()) + s.ErrorAs(err, &handlerErr) + s.Equal(nexus.HandlerErrorTypeInternal, handlerErr.Type) + s.Equal(nexus.HandlerErrorRetryBehaviorNonRetryable, handlerErr.RetryBehavior) + s.Equal("worker", headers.Get("Temporal-Nexus-Failure-Source")) + s.Empty(handlerErr.Message) + s.Error(handlerErr.Cause) + s.Equal("deliberate internal failure", handlerErr.Cause.Error()) }, }, { - name: "handler_timeout", - outcome: "handler_timeout", - endpoint: s.createNexusEndpoint(testcore.RandomizeStr("test-service"), testcore.RandomizeStr("task-queue")), - timeout: 2 * time.Second, - handler: func(res *workflowservice.PollNexusTaskQueueResponse) (*nexusTaskResponse, error) { + name: "handler_timeout", + outcome: "handler_timeout", + endpointName: testcore.RandomizeStr("test-service"), + timeout: 2 * time.Second, + handler: func(t *testing.T, res *workflowservice.PollNexusTaskQueueResponse) (*nexusTaskResponse, error) { timeoutStr, set := res.Request.Header[nexus.HeaderRequestTimeout] - s.True(set) + require.True(t, set) timeout, err := time.ParseDuration(timeoutStr) var dispatchTimeoutBuffer = nexusoperations.MinDispatchTaskTimeout.Get(dynamicconfig.NewNoopCollection())("test") expectedMaxTimeout := 2*time.Second - dispatchTimeoutBuffer - s.LessOrEqual(timeout, expectedMaxTimeout, "timeout should be buffered") + require.LessOrEqual(t, timeout, expectedMaxTimeout, "timeout should be buffered") - s.NoError(err) + require.NoError(t, err) time.Sleep(timeout) //nolint:forbidigo // Allow time.Sleep for timeout tests return nil, nil }, - assertion: func(t *testing.T, res *nexusrpc.ClientStartOperationResponse[string], err error, header http.Header) { + assertion: func(s *NexusApiTestSuite, res *nexusrpc.ClientStartOperationResponse[string], err error, header http.Header) { var handlerErr *nexus.HandlerError - require.ErrorAs(t, err, &handlerErr) - require.Equal(t, nexus.HandlerErrorTypeUpstreamTimeout, handlerErr.Type) - require.Equal(t, "upstream timeout", handlerErr.Message) + s.ErrorAs(err, &handlerErr) + s.Equal(nexus.HandlerErrorTypeUpstreamTimeout, handlerErr.Type) + s.Equal("upstream timeout", handlerErr.Message) }, }, } - testFn := func(t *testing.T, tc testcase, dispatchURL string) { + testFn := func(s *NexusApiTestSuite, tc testcase, dispatchOnlyByEndpoint bool) { + env := newNexusTestEnv(s.T(), useTemporalFailures, testcore.WithDedicatedCluster()) + endpoint := env.createNexusEndpoint(s.T(), tc.endpointName, testcore.RandomizeStr("task-queue")) + var dispatchURL string + if dispatchOnlyByEndpoint { + dispatchURL = getDispatchByEndpointURL(env.HttpAPIAddress(), endpoint.Id) + } else { + dispatchURL = getDispatchByNsAndTqURL(env.HttpAPIAddress(), env.Namespace().String(), endpoint.Spec.Target.GetWorker().TaskQueue) + } ctx, cancel := context.WithCancel(testcore.NewContext()) defer cancel() @@ -282,11 +286,11 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_Outcomes() { Service: "test-service", HTTPCaller: httpCaller, }) - require.NoError(t, err) - capture := s.GetTestCluster().Host().CaptureMetricsHandler().StartCapture() - defer s.GetTestCluster().Host().CaptureMetricsHandler().StopCapture(capture) + s.NoError(err) + capture := env.GetTestCluster().Host().CaptureMetricsHandler().StartCapture() + defer env.GetTestCluster().Host().CaptureMetricsHandler().StopCapture(capture) - pollerErrCh := s.nexusTaskPoller(ctx, tc.endpoint.Spec.Target.GetWorker().TaskQueue, tc.handler) + pollerErrCh := env.nexusTaskPoller(ctx, s.T(), endpoint.Spec.Target.GetWorker().TaskQueue, tc.handler) eventuallyTick := 500 * time.Millisecond header := nexus.Header{"key": "value", "temporal-nexus-failure-support": "true"} @@ -309,23 +313,23 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_Outcomes() { return err == nil || !(errors.As(err, &handlerErr) && handlerErr.Type == nexus.HandlerErrorTypeNotFound) }, 10*time.Second, eventuallyTick) - tc.assertion(t, result, err, headerCapture.lastHeaders) + tc.assertion(s, result, err, headerCapture.lastHeaders) s.NoError(<-pollerErrCh) snap := capture.Snapshot() - require.Len(t, snap["nexus_requests"], 1) - require.Subset(t, snap["nexus_requests"][0].Tags, map[string]string{"namespace": s.Namespace().String(), "method": "StartNexusOperation", "outcome": tc.outcome}) - require.Contains(t, snap["nexus_requests"][0].Tags, "nexus_endpoint") - require.Equal(t, int64(1), snap["nexus_requests"][0].Value) - require.Equal(t, metrics.MetricUnit(""), snap["nexus_requests"][0].Unit) + s.Len(snap["nexus_requests"], 1) + s.Subset(snap["nexus_requests"][0].Tags, map[string]string{"namespace": env.Namespace().String(), "method": "StartNexusOperation", "outcome": tc.outcome}) + s.Contains(snap["nexus_requests"][0].Tags, "nexus_endpoint") + s.Equal(int64(1), snap["nexus_requests"][0].Value) + s.Equal(metrics.MetricUnit(""), snap["nexus_requests"][0].Unit) - require.Len(t, snap["nexus_latency"], 1) - require.Subset(t, snap["nexus_latency"][0].Tags, map[string]string{"namespace": s.Namespace().String(), "method": "StartNexusOperation", "outcome": tc.outcome}) - require.Contains(t, snap["nexus_latency"][0].Tags, "nexus_endpoint") + s.Len(snap["nexus_latency"], 1) + s.Subset(snap["nexus_latency"][0].Tags, map[string]string{"namespace": env.Namespace().String(), "method": "StartNexusOperation", "outcome": tc.outcome}) + s.Contains(snap["nexus_latency"][0].Tags, "nexus_endpoint") // Ensure that StartOperation request is tracked as part of normal service telemetry metrics - require.Condition(t, func() bool { + s.Condition(func() bool { for _, m := range snap["service_requests"] { if opTag, ok := m.Tags["operation"]; ok && opTag == "StartNexusOperation" { return true @@ -336,38 +340,33 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_Outcomes() { } for _, tc := range testCases { - s.T().Run(tc.name, func(t *testing.T) { + s.Run(tc.name, func(s *NexusApiTestSuite) { if !tc.onlyByEndpoint { - t.Run("ByNamespaceAndTaskQueue", func(t *testing.T) { - testFn(t, tc, getDispatchByNsAndTqURL(s.HttpAPIAddress(), s.Namespace().String(), tc.endpoint.Spec.Target.GetWorker().TaskQueue)) - }) + s.Run("ByNamespaceAndTaskQueue", func(s *NexusApiTestSuite) { testFn(s, tc, false) }) } - t.Run("ByEndpoint", func(t *testing.T) { - testFn(t, tc, getDispatchByEndpointURL(s.HttpAPIAddress(), tc.endpoint.Id)) - }) + s.Run("ByEndpoint", func(s *NexusApiTestSuite) { testFn(s, tc, true) }) }) } } -func (s *NexusApiTestSuite) TestNexusStartOperation_Claims() { +func (s *NexusApiTestSuite) TestNexusStartOperation_Claims(useTemporalFailures bool) { taskQueue := testcore.RandomizeStr("task-queue") - testEndpoint := s.createNexusEndpoint(testcore.RandomizeStr("test-endpoint"), taskQueue) type testcase struct { name string header nexus.Header handler nexusTaskHandler - assertion func(*testing.T, *nexusrpc.ClientStartOperationResponse[string], error, map[string][]*metricstest.CapturedRecording) + assertion func(*NexusApiTestSuite, *nexusrpc.ClientStartOperationResponse[string], error, map[string][]*metricstest.CapturedRecording) } testCases := []testcase{ { name: "no header", - assertion: func(t *testing.T, res *nexusrpc.ClientStartOperationResponse[string], err error, snap map[string][]*metricstest.CapturedRecording) { + assertion: func(s *NexusApiTestSuite, res *nexusrpc.ClientStartOperationResponse[string], err error, snap map[string][]*metricstest.CapturedRecording) { var handlerErr *nexus.HandlerError - require.ErrorAs(t, err, &handlerErr) - require.Equal(t, nexus.HandlerErrorTypeUnauthorized, handlerErr.Type) - require.Equal(t, "permission denied", handlerErr.Message) - require.Empty(t, snap["nexus_request_preprocess_errors"]) + s.ErrorAs(err, &handlerErr) + s.Equal(nexus.HandlerErrorTypeUnauthorized, handlerErr.Type) + s.Equal("permission denied", handlerErr.Message) + s.Empty(snap["nexus_request_preprocess_errors"]) }, }, { @@ -375,12 +374,12 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_Claims() { header: nexus.Header{ "authorization": "Bearer invalid", }, - assertion: func(t *testing.T, res *nexusrpc.ClientStartOperationResponse[string], err error, snap map[string][]*metricstest.CapturedRecording) { + assertion: func(s *NexusApiTestSuite, res *nexusrpc.ClientStartOperationResponse[string], err error, snap map[string][]*metricstest.CapturedRecording) { var handlerErr *nexus.HandlerError - require.ErrorAs(t, err, &handlerErr) - require.Equal(t, nexus.HandlerErrorTypeUnauthenticated, handlerErr.Type) - require.Equal(t, "unauthorized", handlerErr.Message) - require.Len(t, snap["nexus_request_preprocess_errors"], 1) + s.ErrorAs(err, &handlerErr) + s.Equal(nexus.HandlerErrorTypeUnauthenticated, handlerErr.Type) + s.Equal("unauthorized", handlerErr.Message) + s.Len(snap["nexus_request_preprocess_errors"], 1) }, }, { @@ -389,34 +388,42 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_Claims() { "authorization": "Bearer test", }, handler: nexusEchoHandler, - assertion: func(t *testing.T, res *nexusrpc.ClientStartOperationResponse[string], err error, snap map[string][]*metricstest.CapturedRecording) { - require.NoError(t, err) - require.Equal(t, "input", res.Successful) - require.Empty(t, snap["nexus_request_preprocess_errors"]) + assertion: func(s *NexusApiTestSuite, res *nexusrpc.ClientStartOperationResponse[string], err error, snap map[string][]*metricstest.CapturedRecording) { + s.NoError(err) + s.Equal("input", res.Successful) + s.Empty(snap["nexus_request_preprocess_errors"]) }, }, } - s.GetTestCluster().Host().SetOnAuthorize(func(ctx context.Context, c *authorization.Claims, ct *authorization.CallTarget) (authorization.Result, error) { - if ct.APIName == configs.DispatchNexusTaskByNamespaceAndTaskQueueAPIName && (c == nil || c.Subject != "test") { - return authorization.Result{Decision: authorization.DecisionDeny}, nil - } - if ct.APIName == configs.DispatchNexusTaskByEndpointAPIName && (c == nil || c.Subject != "test") { - return authorization.Result{Decision: authorization.DecisionDeny}, nil - } - return authorization.Result{Decision: authorization.DecisionAllow}, nil - }) - defer s.GetTestCluster().Host().SetOnAuthorize(nil) - - s.GetTestCluster().Host().SetOnGetClaims(func(ai *authorization.AuthInfo) (*authorization.Claims, error) { - if ai.AuthToken != "Bearer test" { - return nil, errors.New("invalid auth token") + testFn := func(s *NexusApiTestSuite, tc testcase, dispatchOnlyByEndpoint bool) { + env := newNexusTestEnv(s.T(), useTemporalFailures, testcore.WithDedicatedCluster()) + env.GetTestCluster().Host().SetOnAuthorize(func(ctx context.Context, c *authorization.Claims, ct *authorization.CallTarget) (authorization.Result, error) { + if ct.APIName == configs.DispatchNexusTaskByNamespaceAndTaskQueueAPIName && (c == nil || c.Subject != "test") { + return authorization.Result{Decision: authorization.DecisionDeny}, nil + } + if ct.APIName == configs.DispatchNexusTaskByEndpointAPIName && (c == nil || c.Subject != "test") { + return authorization.Result{Decision: authorization.DecisionDeny}, nil + } + return authorization.Result{Decision: authorization.DecisionAllow}, nil + }) + defer env.GetTestCluster().Host().SetOnAuthorize(nil) + env.GetTestCluster().Host().SetOnGetClaims(func(ai *authorization.AuthInfo) (*authorization.Claims, error) { + if ai.AuthToken != "Bearer test" { + return nil, errors.New("invalid auth token") + } + return &authorization.Claims{Subject: "test"}, nil + }) + defer env.GetTestCluster().Host().SetOnGetClaims(nil) + + testEndpoint := env.createNexusEndpoint(s.T(), testcore.RandomizeStr("test-endpoint"), taskQueue) + var dispatchURL string + if dispatchOnlyByEndpoint { + dispatchURL = getDispatchByEndpointURL(env.HttpAPIAddress(), testEndpoint.Id) + } else { + dispatchURL = getDispatchByNsAndTqURL(env.HttpAPIAddress(), env.Namespace().String(), taskQueue) } - return &authorization.Claims{Subject: "test"}, nil - }) - defer s.GetTestCluster().Host().SetOnGetClaims(nil) - testFn := func(t *testing.T, tc testcase, dispatchURL string) { ctx, cancel := context.WithCancel(testcore.NewContext()) defer cancel() @@ -426,71 +433,67 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_Claims() { var pollerErrCh <-chan error if tc.handler != nil { // only set on valid request - pollerErrCh = s.nexusTaskPoller(ctx, taskQueue, tc.handler) + pollerErrCh = env.nexusTaskPoller(ctx, s.T(), taskQueue, tc.handler) } - capture := s.GetTestCluster().Host().CaptureMetricsHandler().StartCapture() + capture := env.GetTestCluster().Host().CaptureMetricsHandler().StartCapture() result, err := nexusrpc.StartOperation(ctx, client, op, "input", nexus.StartOperationOptions{ Header: tc.header, }) snap := capture.Snapshot() - s.GetTestCluster().Host().CaptureMetricsHandler().StopCapture(capture) + env.GetTestCluster().Host().CaptureMetricsHandler().StopCapture(capture) - tc.assertion(t, result, err, snap) + tc.assertion(s, result, err, snap) if pollerErrCh != nil { s.NoError(<-pollerErrCh) } } for _, tc := range testCases { - s.T().Run(tc.name, func(t *testing.T) { - t.Run("ByNamespaceAndTaskQueue", func(t *testing.T) { - testFn(t, tc, getDispatchByNsAndTqURL(s.HttpAPIAddress(), s.Namespace().String(), taskQueue)) - }) - t.Run("ByEndpoint", func(t *testing.T) { - testFn(t, tc, getDispatchByEndpointURL(s.HttpAPIAddress(), testEndpoint.Id)) - }) + s.Run(tc.name, func(s *NexusApiTestSuite) { + s.Run("ByNamespaceAndTaskQueue", func(s *NexusApiTestSuite) { testFn(s, tc, false) }) + s.Run("ByEndpoint", func(s *NexusApiTestSuite) { testFn(s, tc, true) }) }) } } -func (s *NexusApiTestSuite) TestNexusCancelOperation_Outcomes() { +func (s *NexusApiTestSuite) TestNexusCancelOperation_Outcomes(useTemporalFailures bool) { asyncSuccessEndpoint := testcore.RandomizeStr("async-success-endpoint") type testcase struct { outcome string onlyByEndpoint bool - endpoint *nexuspb.Endpoint + endpointName string timeout time.Duration handler nexusTaskHandler - assertion func(*testing.T, error, http.Header) + assertion func(*NexusApiTestSuite, error, http.Header) } testCases := []testcase{ { outcome: "success", onlyByEndpoint: true, - endpoint: s.createNexusEndpoint(asyncSuccessEndpoint, testcore.RandomizeStr("task-queue")), - handler: func(res *workflowservice.PollNexusTaskQueueResponse) (*nexusTaskResponse, error) { - s.Equal(asyncSuccessEndpoint, res.Request.Endpoint) + endpointName: asyncSuccessEndpoint, + handler: func(t *testing.T, res *workflowservice.PollNexusTaskQueueResponse) (*nexusTaskResponse, error) { + require.Equal(t, asyncSuccessEndpoint, res.Request.Endpoint) // Choose an arbitrary test case to assert that all of the input is delivered to the // poll response. op, ok := res.Request.Variant.(*nexuspb.Request_CancelOperation) - s.True(ok) - s.Equal("test-service", op.CancelOperation.Service) - s.Equal("operation", op.CancelOperation.Operation) - s.Equal("token", op.CancelOperation.OperationToken) - s.Equal("value", res.Request.Header["key"]) + require.True(t, ok) + require.Equal(t, "test-service", op.CancelOperation.Service) + require.Equal(t, "operation", op.CancelOperation.Operation) + require.Equal(t, "token", op.CancelOperation.OperationToken) + require.Equal(t, "value", res.Request.Header["key"]) return &nexusTaskResponse{CancelResult: new(struct{})}, nil }, - assertion: func(t *testing.T, err error, headers http.Header) { - require.NoError(t, err) + assertion: func(s *NexusApiTestSuite, err error, headers http.Header) { + s.NoError(err) }, }, { - outcome: "handler_error:INTERNAL", - endpoint: s.createNexusEndpoint(testcore.RandomizeStr("test-endpoint"), testcore.RandomizeStr("task-queue")), - handler: func(_ *workflowservice.PollNexusTaskQueueResponse) (*nexusTaskResponse, error) { + outcome: "handler_error:INTERNAL", + endpointName: testcore.RandomizeStr("test-endpoint"), + handler: func(_ *testing.T, _ *workflowservice.PollNexusTaskQueueResponse) (*nexusTaskResponse, error) { return nil, &nexus.HandlerError{ Type: nexus.HandlerErrorTypeInternal, Cause: &nexus.FailureError{ @@ -498,38 +501,46 @@ func (s *NexusApiTestSuite) TestNexusCancelOperation_Outcomes() { }, } }, - assertion: func(t *testing.T, err error, headers http.Header) { + assertion: func(s *NexusApiTestSuite, err error, headers http.Header) { var handlerErr *nexus.HandlerError - require.ErrorAs(t, err, &handlerErr) - require.Equal(t, nexus.HandlerErrorTypeInternal, handlerErr.Type) - require.Equal(t, "worker", headers.Get("Temporal-Nexus-Failure-Source")) - require.Empty(t, handlerErr.Message) - require.Error(t, handlerErr.Cause) - require.Equal(t, "deliberate internal failure", handlerErr.Cause.Error()) + s.ErrorAs(err, &handlerErr) + s.Equal(nexus.HandlerErrorTypeInternal, handlerErr.Type) + s.Equal("worker", headers.Get("Temporal-Nexus-Failure-Source")) + s.Empty(handlerErr.Message) + s.Error(handlerErr.Cause) + s.Equal("deliberate internal failure", handlerErr.Cause.Error()) }, }, { - outcome: "handler_timeout", - endpoint: s.createNexusEndpoint(testcore.RandomizeStr("test-service"), testcore.RandomizeStr("task-queue")), - timeout: 2 * time.Second, - handler: func(res *workflowservice.PollNexusTaskQueueResponse) (*nexusTaskResponse, error) { + outcome: "handler_timeout", + endpointName: testcore.RandomizeStr("test-service"), + timeout: 2 * time.Second, + handler: func(t *testing.T, res *workflowservice.PollNexusTaskQueueResponse) (*nexusTaskResponse, error) { timeoutStr, set := res.Request.Header[nexus.HeaderRequestTimeout] - s.True(set) + require.True(t, set) timeout, err := time.ParseDuration(timeoutStr) - s.NoError(err) + require.NoError(t, err) time.Sleep(timeout) //nolint:forbidigo // Allow time.Sleep for timeout tests return nil, nil }, - assertion: func(t *testing.T, err error, headers http.Header) { + assertion: func(s *NexusApiTestSuite, err error, headers http.Header) { var handlerErr *nexus.HandlerError - require.ErrorAs(t, err, &handlerErr) - require.Equal(t, nexus.HandlerErrorTypeUpstreamTimeout, handlerErr.Type) - require.Equal(t, "upstream timeout", handlerErr.Message) + s.ErrorAs(err, &handlerErr) + s.Equal(nexus.HandlerErrorTypeUpstreamTimeout, handlerErr.Type) + s.Equal("upstream timeout", handlerErr.Message) }, }, } - testFn := func(t *testing.T, tc testcase, dispatchURL string) { + testFn := func(s *NexusApiTestSuite, tc testcase, dispatchOnlyByEndpoint bool) { + env := newNexusTestEnv(s.T(), useTemporalFailures, testcore.WithDedicatedCluster()) + endpoint := env.createNexusEndpoint(s.T(), tc.endpointName, testcore.RandomizeStr("task-queue")) + var dispatchURL string + if dispatchOnlyByEndpoint { + dispatchURL = getDispatchByEndpointURL(env.HttpAPIAddress(), endpoint.Id) + } else { + dispatchURL = getDispatchByNsAndTqURL(env.HttpAPIAddress(), env.Namespace().String(), endpoint.Spec.Target.GetWorker().TaskQueue) + } ctx, cancel := context.WithCancel(testcore.NewContext()) defer cancel() @@ -539,14 +550,14 @@ func (s *NexusApiTestSuite) TestNexusCancelOperation_Outcomes() { Service: "test-service", HTTPCaller: httpCaller, }) - require.NoError(t, err) - capture := s.GetTestCluster().Host().CaptureMetricsHandler().StartCapture() - defer s.GetTestCluster().Host().CaptureMetricsHandler().StopCapture(capture) + s.NoError(err) + capture := env.GetTestCluster().Host().CaptureMetricsHandler().StartCapture() + defer env.GetTestCluster().Host().CaptureMetricsHandler().StopCapture(capture) - pollerErrCh := s.nexusTaskPoller(ctx, tc.endpoint.Spec.Target.GetWorker().TaskQueue, tc.handler) + pollerErrCh := env.nexusTaskPoller(ctx, s.T(), endpoint.Spec.Target.GetWorker().TaskQueue, tc.handler) handle, err := client.NewOperationHandle("operation", "token") - require.NoError(t, err) + s.NoError(err) eventuallyTick := 500 * time.Millisecond header := nexus.Header{"key": "value"} @@ -562,23 +573,23 @@ func (s *NexusApiTestSuite) TestNexusCancelOperation_Outcomes() { return err == nil || !(errors.As(err, &handlerErr) && handlerErr.Type == nexus.HandlerErrorTypeNotFound) }, 10*time.Second, eventuallyTick) - tc.assertion(t, err, headerCapture.lastHeaders) + tc.assertion(s, err, headerCapture.lastHeaders) s.NoError(<-pollerErrCh) snap := capture.Snapshot() - require.Len(t, snap["nexus_requests"], 1) - require.Subset(t, snap["nexus_requests"][0].Tags, map[string]string{"namespace": s.Namespace().String(), "method": "CancelNexusOperation", "outcome": tc.outcome}) - require.Contains(t, snap["nexus_requests"][0].Tags, "nexus_endpoint") - require.Equal(t, int64(1), snap["nexus_requests"][0].Value) - require.Equal(t, metrics.MetricUnit(""), snap["nexus_requests"][0].Unit) + s.Len(snap["nexus_requests"], 1) + s.Subset(snap["nexus_requests"][0].Tags, map[string]string{"namespace": env.Namespace().String(), "method": "CancelNexusOperation", "outcome": tc.outcome}) + s.Contains(snap["nexus_requests"][0].Tags, "nexus_endpoint") + s.Equal(int64(1), snap["nexus_requests"][0].Value) + s.Equal(metrics.MetricUnit(""), snap["nexus_requests"][0].Unit) - require.Len(t, snap["nexus_latency"], 1) - require.Subset(t, snap["nexus_latency"][0].Tags, map[string]string{"namespace": s.Namespace().String(), "method": "CancelNexusOperation", "outcome": tc.outcome}) - require.Contains(t, snap["nexus_latency"][0].Tags, "nexus_endpoint") + s.Len(snap["nexus_latency"], 1) + s.Subset(snap["nexus_latency"][0].Tags, map[string]string{"namespace": env.Namespace().String(), "method": "CancelNexusOperation", "outcome": tc.outcome}) + s.Contains(snap["nexus_latency"][0].Tags, "nexus_endpoint") // Ensure that CancelOperation request is tracked as part of normal service telemetry metrics - require.Condition(t, func() bool { + s.Condition(func() bool { for _, m := range snap["service_requests"] { if opTag, ok := m.Tags["operation"]; ok && opTag == "CancelNexusOperation" { return true @@ -589,39 +600,37 @@ func (s *NexusApiTestSuite) TestNexusCancelOperation_Outcomes() { } for _, tc := range testCases { - s.T().Run(tc.outcome, func(t *testing.T) { + s.Run(tc.outcome, func(s *NexusApiTestSuite) { if !tc.onlyByEndpoint { - t.Run("ByNamespaceAndTaskQueue", func(t *testing.T) { - testFn(t, tc, getDispatchByNsAndTqURL(s.HttpAPIAddress(), s.Namespace().String(), tc.endpoint.Spec.Target.GetWorker().TaskQueue)) - }) + s.Run("ByNamespaceAndTaskQueue", func(s *NexusApiTestSuite) { testFn(s, tc, false) }) } - t.Run("ByEndpoint", func(t *testing.T) { - testFn(t, tc, getDispatchByEndpointURL(s.HttpAPIAddress(), tc.endpoint.Id)) - }) + s.Run("ByEndpoint", func(s *NexusApiTestSuite) { testFn(s, tc, true) }) }) } } -func (s *NexusApiTestSuite) TestNexusStartOperation_WithNamespaceAndTaskQueue_SupportsVersioning() { +func (s *NexusApiTestSuite) TestNexusStartOperation_WithNamespaceAndTaskQueue_SupportsVersioning(useTemporalFailures bool) { + env := newNexusTestEnv(s.T(), useTemporalFailures, testcore.WithDedicatedCluster()) + env.OverrideDynamicConfig(dynamicconfig.FrontendEnableWorkerVersioningRuleAPIs, true) ctx, cancel := context.WithCancel(testcore.NewContext()) defer cancel() taskQueue := testcore.RandomizeStr("task-queue") - err := s.SdkClient().UpdateWorkerBuildIdCompatibility(ctx, &sdkclient.UpdateWorkerBuildIdCompatibilityOptions{ + err := env.SdkClient().UpdateWorkerBuildIdCompatibility(ctx, &sdkclient.UpdateWorkerBuildIdCompatibilityOptions{ //nolint:staticcheck // SA1019 deprecated TaskQueue: taskQueue, Operation: &sdkclient.BuildIDOpAddNewIDInNewDefaultSet{BuildID: "old-build-id"}, }) s.NoError(err) - err = s.SdkClient().UpdateWorkerBuildIdCompatibility(ctx, &sdkclient.UpdateWorkerBuildIdCompatibilityOptions{ + err = env.SdkClient().UpdateWorkerBuildIdCompatibility(ctx, &sdkclient.UpdateWorkerBuildIdCompatibilityOptions{ //nolint:staticcheck // SA1019 deprecated TaskQueue: taskQueue, Operation: &sdkclient.BuildIDOpAddNewIDInNewDefaultSet{BuildID: "new-build-id"}, }) s.NoError(err) - u := getDispatchByNsAndTqURL(s.HttpAPIAddress(), s.Namespace().String(), taskQueue) + u := getDispatchByNsAndTqURL(env.HttpAPIAddress(), env.Namespace().String(), taskQueue) client, err := nexusrpc.NewHTTPClient(nexusrpc.HTTPClientOptions{BaseURL: u, Service: "test-service"}) s.NoError(err) // Versioned poller gets task - pollerErrCh1 := s.versionedNexusTaskPoller(ctx, taskQueue, "new-build-id", nexusEchoHandler) + pollerErrCh1 := env.versionedNexusTaskPoller(ctx, s.T(), taskQueue, "new-build-id", nexusEchoHandler) result, err := nexusrpc.StartOperation(ctx, client, op, "input", nexus.StartOperationOptions{}) s.NoError(err) @@ -629,9 +638,9 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_WithNamespaceAndTaskQueue_Su s.NoError(<-pollerErrCh1) // Unversioned poller doesn't get a task - pollerErrCh2 := s.nexusTaskPoller(ctx, taskQueue, nexusEchoHandler) + pollerErrCh2 := env.nexusTaskPoller(ctx, s.T(), taskQueue, nexusEchoHandler) // Versioned poller gets task with wrong build ID - pollerErrCh3 := s.versionedNexusTaskPoller(ctx, taskQueue, "old-build-id", nexusEchoHandler) + pollerErrCh3 := env.versionedNexusTaskPoller(ctx, s.T(), taskQueue, "old-build-id", nexusEchoHandler) timeoutCtx, timeoutCancel := context.WithTimeout(ctx, time.Second*2) defer timeoutCancel() @@ -651,13 +660,14 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_WithNamespaceAndTaskQueue_Su // TestNexusClientNameMetricPropagation verifies that when an SDK worker polls for Nexus tasks // with client-name in gRPC metadata, the matching service emits nexus_task_requests with a // client_name tag. This proves the header propagates e2e: SDK → frontend → matching. -func (s *NexusApiTestSuite) TestNexusClientNameMetricPropagation() { +func (s *NexusApiTestSuite) TestNexusClientNameMetricPropagation(useTemporalFailures bool) { + env := newNexusTestEnv(s.T(), useTemporalFailures, testcore.WithDedicatedCluster()) const expectedClientName = "temporal-go" taskQueue := testcore.RandomizeStr("tq") - endpoint := s.createNexusEndpoint(testcore.RandomizeStr("endpoint"), taskQueue) + endpoint := env.createNexusEndpoint(s.T(), testcore.RandomizeStr("endpoint"), taskQueue) - capture := s.GetTestCluster().Host().CaptureMetricsHandler().StartCapture() - defer s.GetTestCluster().Host().CaptureMetricsHandler().StopCapture(capture) + capture := env.GetTestCluster().Host().CaptureMetricsHandler().StartCapture() + defer env.GetTestCluster().Host().CaptureMetricsHandler().StopCapture(capture) ctx, cancel := context.WithCancel(testcore.NewContext()) defer cancel() @@ -671,11 +681,11 @@ func (s *NexusApiTestSuite) TestNexusClientNameMetricPropagation() { "supported-server-versions", headers.SupportedServerVersions, "supported-features", headers.AllFeatures, )) - pollerErrCh := s.nexusTaskPoller(pollerCtx, taskQueue, nexusEchoHandler) + pollerErrCh := env.nexusTaskPoller(pollerCtx, s.T(), taskQueue, nexusEchoHandler) // Trigger a Nexus start operation via HTTP to unblock the poller. client, err := nexusrpc.NewHTTPClient(nexusrpc.HTTPClientOptions{ - BaseURL: getDispatchByEndpointURL(s.HttpAPIAddress(), endpoint.Id), + BaseURL: getDispatchByEndpointURL(env.HttpAPIAddress(), endpoint.Id), Service: "test-service", }) s.NoError(err) @@ -701,7 +711,7 @@ func (s *NexusApiTestSuite) TestNexusClientNameMetricPropagation() { expectedClientName, snap["nexus_task_requests"]) } -func nexusEchoHandler(res *workflowservice.PollNexusTaskQueueResponse) (*nexusTaskResponse, error) { +func nexusEchoHandler(_ *testing.T, res *workflowservice.PollNexusTaskQueueResponse) (*nexusTaskResponse, error) { return &nexusTaskResponse{StartResult: &nexus.HandlerStartOperationResultSync[*commonpb.Payload]{Value: res.Request.GetStartOperation().GetPayload()}}, nil } diff --git a/tests/nexus_api_validation_test.go b/tests/nexus_api_validation_test.go index 523b4b44d41..bcd26ea96e5 100644 --- a/tests/nexus_api_validation_test.go +++ b/tests/nexus_api_validation_test.go @@ -10,8 +10,6 @@ import ( "github.com/google/uuid" "github.com/nexus-rpc/sdk-go/nexus" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" nexuspb "go.temporal.io/api/nexus/v1" "go.temporal.io/api/serviceerror" "go.temporal.io/api/workflowservice/v1" @@ -19,29 +17,30 @@ import ( "go.temporal.io/server/common/authorization" "go.temporal.io/server/common/dynamicconfig" "go.temporal.io/server/common/nexus/nexusrpc" + "go.temporal.io/server/common/testing/parallelsuite" "go.temporal.io/server/service/frontend/configs" "go.temporal.io/server/tests/testcore" ) type NexusAPIValidationTestSuite struct { - NexusTestBaseSuite + parallelsuite.Suite[*NexusAPIValidationTestSuite] } func TestNexusAPIValidationTestSuite(t *testing.T) { - t.Parallel() - suite.Run(t, new(NexusAPIValidationTestSuite)) + parallelsuite.Run(t, &NexusAPIValidationTestSuite{}) } func (s *NexusAPIValidationTestSuite) TestNexusStartOperation_WithNamespaceAndTaskQueue_NamespaceNotFound() { + env := newNexusTestEnv(s.T(), false, testcore.WithDedicatedCluster()) // Also use this test to verify that namespaces are unescaped in the path. taskQueue := testcore.RandomizeStr("task-queue") namespace := "namespace not/found" - u := getDispatchByNsAndTqURL(s.HttpAPIAddress(), namespace, taskQueue) + u := getDispatchByNsAndTqURL(env.HttpAPIAddress(), namespace, taskQueue) client, err := nexusrpc.NewHTTPClient(nexusrpc.HTTPClientOptions{BaseURL: u, Service: "test-service"}) s.NoError(err) ctx := testcore.NewContext() - capture := s.GetTestCluster().Host().CaptureMetricsHandler().StartCapture() - defer s.GetTestCluster().Host().CaptureMetricsHandler().StopCapture(capture) + capture := env.GetTestCluster().Host().CaptureMetricsHandler().StartCapture() + defer env.GetTestCluster().Host().CaptureMetricsHandler().StopCapture(capture) _, err = nexusrpc.StartOperation(ctx, client, op, "input", nexus.StartOperationOptions{}) var handlerError *nexus.HandlerError s.ErrorAs(err, &handlerError) @@ -56,6 +55,7 @@ func (s *NexusAPIValidationTestSuite) TestNexusStartOperation_WithNamespaceAndTa } func (s *NexusAPIValidationTestSuite) TestNexusStartOperation_WithNamespaceAndTaskQueue_NamespaceTooLong() { + env := newNexusTestEnv(s.T(), false, testcore.WithDedicatedCluster()) taskQueue := testcore.RandomizeStr("task-queue") var namespace string @@ -63,12 +63,12 @@ func (s *NexusAPIValidationTestSuite) TestNexusStartOperation_WithNamespaceAndTa namespace += "namespace-is-a-very-long-string" } - u := getDispatchByNsAndTqURL(s.HttpAPIAddress(), namespace, taskQueue) + u := getDispatchByNsAndTqURL(env.HttpAPIAddress(), namespace, taskQueue) client, err := nexusrpc.NewHTTPClient(nexusrpc.HTTPClientOptions{BaseURL: u, Service: "test-service"}) s.NoError(err) ctx := testcore.NewContext() - capture := s.GetTestCluster().Host().CaptureMetricsHandler().StartCapture() - defer s.GetTestCluster().Host().CaptureMetricsHandler().StopCapture(capture) + capture := env.GetTestCluster().Host().CaptureMetricsHandler().StartCapture() + defer env.GetTestCluster().Host().CaptureMetricsHandler().StopCapture(capture) _, err = nexusrpc.StartOperation(ctx, client, op, "input", nexus.StartOperationOptions{}) var handlerErr *nexus.HandlerError s.ErrorAs(err, &handlerErr) @@ -82,112 +82,131 @@ func (s *NexusAPIValidationTestSuite) TestNexusStartOperation_WithNamespaceAndTa } func (s *NexusAPIValidationTestSuite) TestNexusStartOperation_Forbidden() { - taskQueue := testcore.RandomizeStr("task-queue") - testEndpoint := s.createNexusEndpoint(testcore.RandomizeStr("test-endpoint"), taskQueue) - type testcase struct { name string - onAuthorize func(context.Context, *authorization.Claims, *authorization.CallTarget) (authorization.Result, error) - checkFailure func(t *testing.T, handlerErr *nexus.HandlerError) + onAuthorize func(endpointName string) func(context.Context, *authorization.Claims, *authorization.CallTarget) (authorization.Result, error) + checkFailure func(s *NexusAPIValidationTestSuite, handlerErr *nexus.HandlerError) exposeAuthorizerErrors bool expectedOutcomeMetric string } testCases := []testcase{ { name: "deny with reason", - onAuthorize: func(ctx context.Context, c *authorization.Claims, ct *authorization.CallTarget) (authorization.Result, error) { - if ct.APIName == configs.DispatchNexusTaskByNamespaceAndTaskQueueAPIName { - return authorization.Result{Decision: authorization.DecisionDeny, Reason: "unauthorized in test"}, nil - } - if ct.APIName == configs.DispatchNexusTaskByEndpointAPIName { - if ct.NexusEndpointName != testEndpoint.Spec.Name { - panic("expected nexus endpoint name") + onAuthorize: func(endpointName string) func(context.Context, *authorization.Claims, *authorization.CallTarget) (authorization.Result, error) { + return func(ctx context.Context, c *authorization.Claims, ct *authorization.CallTarget) (authorization.Result, error) { + if ct.APIName == configs.DispatchNexusTaskByNamespaceAndTaskQueueAPIName { + return authorization.Result{Decision: authorization.DecisionDeny, Reason: "unauthorized in test"}, nil } - return authorization.Result{Decision: authorization.DecisionDeny, Reason: "unauthorized in test"}, nil + if ct.APIName == configs.DispatchNexusTaskByEndpointAPIName { + if ct.NexusEndpointName != endpointName { + panic("expected nexus endpoint name") + } + return authorization.Result{Decision: authorization.DecisionDeny, Reason: "unauthorized in test"}, nil + } + return authorization.Result{Decision: authorization.DecisionAllow}, nil } - return authorization.Result{Decision: authorization.DecisionAllow}, nil }, - checkFailure: func(t *testing.T, handlerErr *nexus.HandlerError) { - require.Equal(t, nexus.HandlerErrorTypeUnauthorized, handlerErr.Type) - require.Equal(t, "permission denied: unauthorized in test", handlerErr.Message) + checkFailure: func(s *NexusAPIValidationTestSuite, handlerErr *nexus.HandlerError) { + s.Equal(nexus.HandlerErrorTypeUnauthorized, handlerErr.Type) + s.Equal("permission denied: unauthorized in test", handlerErr.Message) }, expectedOutcomeMetric: "unauthorized", exposeAuthorizerErrors: false, }, { name: "deny without reason", - onAuthorize: func(ctx context.Context, c *authorization.Claims, ct *authorization.CallTarget) (authorization.Result, error) { - if ct.APIName == configs.DispatchNexusTaskByNamespaceAndTaskQueueAPIName { - return authorization.Result{Decision: authorization.DecisionDeny}, nil - } - if ct.APIName == configs.DispatchNexusTaskByEndpointAPIName { - if ct.NexusEndpointName != testEndpoint.Spec.Name { - panic("expected nexus endpoint name") + onAuthorize: func(endpointName string) func(context.Context, *authorization.Claims, *authorization.CallTarget) (authorization.Result, error) { + return func(ctx context.Context, c *authorization.Claims, ct *authorization.CallTarget) (authorization.Result, error) { + if ct.APIName == configs.DispatchNexusTaskByNamespaceAndTaskQueueAPIName { + return authorization.Result{Decision: authorization.DecisionDeny}, nil + } + if ct.APIName == configs.DispatchNexusTaskByEndpointAPIName { + if ct.NexusEndpointName != endpointName { + panic("expected nexus endpoint name") + } + return authorization.Result{Decision: authorization.DecisionDeny}, nil } - return authorization.Result{Decision: authorization.DecisionDeny}, nil + return authorization.Result{Decision: authorization.DecisionAllow}, nil } - return authorization.Result{Decision: authorization.DecisionAllow}, nil }, - checkFailure: func(t *testing.T, handlerErr *nexus.HandlerError) { - require.Equal(t, nexus.HandlerErrorTypeUnauthorized, handlerErr.Type) - require.Equal(t, "permission denied", handlerErr.Message) + checkFailure: func(s *NexusAPIValidationTestSuite, handlerErr *nexus.HandlerError) { + s.Equal(nexus.HandlerErrorTypeUnauthorized, handlerErr.Type) + s.Equal("permission denied", handlerErr.Message) }, expectedOutcomeMetric: "unauthorized", exposeAuthorizerErrors: false, }, { name: "deny with generic error", - onAuthorize: func(ctx context.Context, c *authorization.Claims, ct *authorization.CallTarget) (authorization.Result, error) { - if ct.APIName == configs.DispatchNexusTaskByNamespaceAndTaskQueueAPIName { - return authorization.Result{}, errors.New("some generic error") - } - if ct.APIName == configs.DispatchNexusTaskByEndpointAPIName { - if ct.NexusEndpointName != testEndpoint.Spec.Name { - panic("expected nexus endpoint name") + onAuthorize: func(endpointName string) func(context.Context, *authorization.Claims, *authorization.CallTarget) (authorization.Result, error) { + return func(ctx context.Context, c *authorization.Claims, ct *authorization.CallTarget) (authorization.Result, error) { + if ct.APIName == configs.DispatchNexusTaskByNamespaceAndTaskQueueAPIName { + return authorization.Result{}, errors.New("some generic error") + } + if ct.APIName == configs.DispatchNexusTaskByEndpointAPIName { + if ct.NexusEndpointName != endpointName { + panic("expected nexus endpoint name") + } + return authorization.Result{}, errors.New("some generic error") } - return authorization.Result{}, errors.New("some generic error") + return authorization.Result{Decision: authorization.DecisionAllow}, nil } - return authorization.Result{Decision: authorization.DecisionAllow}, nil }, - checkFailure: func(t *testing.T, handlerErr *nexus.HandlerError) { - require.Equal(t, nexus.HandlerErrorTypeUnauthorized, handlerErr.Type) - require.Equal(t, "permission denied", handlerErr.Message) + checkFailure: func(s *NexusAPIValidationTestSuite, handlerErr *nexus.HandlerError) { + s.Equal(nexus.HandlerErrorTypeUnauthorized, handlerErr.Type) + s.Equal("permission denied", handlerErr.Message) }, expectedOutcomeMetric: "unauthorized", exposeAuthorizerErrors: false, }, { name: "deny with exposed error", - onAuthorize: func(ctx context.Context, c *authorization.Claims, ct *authorization.CallTarget) (authorization.Result, error) { - if ct.APIName == configs.DispatchNexusTaskByNamespaceAndTaskQueueAPIName { - return authorization.Result{}, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeUnavailable, "exposed error") - } - if ct.APIName == configs.DispatchNexusTaskByEndpointAPIName { - if ct.NexusEndpointName != testEndpoint.Spec.Name { - panic("expected nexus endpoint name") + onAuthorize: func(endpointName string) func(context.Context, *authorization.Claims, *authorization.CallTarget) (authorization.Result, error) { + return func(ctx context.Context, c *authorization.Claims, ct *authorization.CallTarget) (authorization.Result, error) { + if ct.APIName == configs.DispatchNexusTaskByNamespaceAndTaskQueueAPIName { + return authorization.Result{}, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeUnavailable, "exposed error") } - return authorization.Result{}, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeUnavailable, "exposed error") + if ct.APIName == configs.DispatchNexusTaskByEndpointAPIName { + if ct.NexusEndpointName != endpointName { + panic("expected nexus endpoint name") + } + return authorization.Result{}, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeUnavailable, "exposed error") + } + return authorization.Result{Decision: authorization.DecisionAllow}, nil } - return authorization.Result{Decision: authorization.DecisionAllow}, nil }, - checkFailure: func(t *testing.T, handlerErr *nexus.HandlerError) { - require.Equal(t, nexus.HandlerErrorTypeUnavailable, handlerErr.Type) - require.Equal(t, "exposed error", handlerErr.Message) + checkFailure: func(s *NexusAPIValidationTestSuite, handlerErr *nexus.HandlerError) { + s.Equal(nexus.HandlerErrorTypeUnavailable, handlerErr.Type) + s.Equal("exposed error", handlerErr.Message) }, expectedOutcomeMetric: "internal_auth_error", exposeAuthorizerErrors: true, }, } - testFn := func(t *testing.T, tc testcase, dispatchURL string) { + testFn := func(s *NexusAPIValidationTestSuite, tc testcase, dispatchOnlyByEndpoint bool) { + env := newNexusTestEnv(s.T(), false, testcore.WithDedicatedCluster()) + taskQueue := testcore.RandomizeStr("task-queue") + testEndpoint := env.createNexusEndpoint(s.T(), testcore.RandomizeStr("test-endpoint"), taskQueue) + + env.GetTestCluster().Host().SetOnAuthorize(tc.onAuthorize(testEndpoint.Spec.Name)) + s.T().Cleanup(func() { env.GetTestCluster().Host().SetOnAuthorize(nil) }) + + env.OverrideDynamicConfig(dynamicconfig.ExposeAuthorizerErrors, tc.exposeAuthorizerErrors) + + var dispatchURL string + if dispatchOnlyByEndpoint { + dispatchURL = getDispatchByEndpointURL(env.HttpAPIAddress(), testEndpoint.Id) + } else { + dispatchURL = getDispatchByNsAndTqURL(env.HttpAPIAddress(), env.Namespace().String(), taskQueue) + } + client, err := nexusrpc.NewHTTPClient(nexusrpc.HTTPClientOptions{BaseURL: dispatchURL, Service: "test-service"}) - require.NoError(t, err) + s.NoError(err) ctx := testcore.NewContext() - capture := s.GetTestCluster().Host().CaptureMetricsHandler().StartCapture() - defer s.GetTestCluster().Host().CaptureMetricsHandler().StopCapture(capture) - - s.OverrideDynamicConfig(dynamicconfig.ExposeAuthorizerErrors, tc.exposeAuthorizerErrors) + capture := env.GetTestCluster().Host().CaptureMetricsHandler().StartCapture() + defer env.GetTestCluster().Host().CaptureMetricsHandler().StopCapture(capture) // Wait until the endpoint is loaded into the registry. s.Eventually(func() bool { @@ -197,47 +216,48 @@ func (s *NexusAPIValidationTestSuite) TestNexusStartOperation_Forbidden() { }, 10*time.Second, 1*time.Second) var handlerErr *nexus.HandlerError - require.ErrorAs(t, err, &handlerErr) - tc.checkFailure(t, handlerErr) + s.ErrorAs(err, &handlerErr) + tc.checkFailure(s, handlerErr) snap := capture.Snapshot() - require.Len(t, snap["nexus_requests"], 1) - require.Subset(t, snap["nexus_requests"][0].Tags, map[string]string{"namespace": s.Namespace().String(), "method": "StartNexusOperation", "outcome": tc.expectedOutcomeMetric}) - require.Equal(t, int64(1), snap["nexus_requests"][0].Value) + s.Len(snap["nexus_requests"], 1) + s.Subset(snap["nexus_requests"][0].Tags, map[string]string{"namespace": env.Namespace().String(), "method": "StartNexusOperation", "outcome": tc.expectedOutcomeMetric}) + s.Equal(int64(1), snap["nexus_requests"][0].Value) } for _, tc := range testCases { - s.Run(tc.name, func() { - s.GetTestCluster().Host().SetOnAuthorize(tc.onAuthorize) - defer s.GetTestCluster().Host().SetOnAuthorize(nil) - - s.Run("ByNamespaceAndTaskQueue", func() { - testFn(s.T(), tc, getDispatchByNsAndTqURL(s.HttpAPIAddress(), s.Namespace().String(), taskQueue)) - }) - s.Run("ByEndpoint", func() { - testFn(s.T(), tc, getDispatchByEndpointURL(s.HttpAPIAddress(), testEndpoint.Id)) - }) + s.Run(tc.name, func(s *NexusAPIValidationTestSuite) { + s.Run("ByNamespaceAndTaskQueue", func(s *NexusAPIValidationTestSuite) { testFn(s, tc, false) }) + s.Run("ByEndpoint", func(s *NexusAPIValidationTestSuite) { testFn(s, tc, true) }) }) } } func (s *NexusAPIValidationTestSuite) TestNexusStartOperation_PayloadSizeLimit() { - taskQueue := testcore.RandomizeStr("task-queue") - testEndpoint := s.createNexusEndpoint(testcore.RandomizeStr("test-endpoint"), taskQueue) - // Use -10 to avoid hitting MaxNexusAPIRequestBodyBytes. Actual payload will still exceed limit because of // additional Content headers. See common/rpc/grpc.go:66 input := strings.Repeat("a", (2*1024*1024)-10) - testFn := func(t *testing.T, dispatchURL string) { + testFn := func(s *NexusAPIValidationTestSuite, dispatchOnlyByEndpoint bool) { + env := newNexusTestEnv(s.T(), false, testcore.WithDedicatedCluster()) + taskQueue := testcore.RandomizeStr("task-queue") + testEndpoint := env.createNexusEndpoint(s.T(), testcore.RandomizeStr("test-endpoint"), taskQueue) + + var dispatchURL string + if dispatchOnlyByEndpoint { + dispatchURL = getDispatchByEndpointURL(env.HttpAPIAddress(), testEndpoint.Id) + } else { + dispatchURL = getDispatchByNsAndTqURL(env.HttpAPIAddress(), env.Namespace().String(), taskQueue) + } + ctx, cancel := context.WithCancel(testcore.NewContext()) defer cancel() client, err := nexusrpc.NewHTTPClient(nexusrpc.HTTPClientOptions{BaseURL: dispatchURL, Service: "test-service"}) - require.NoError(t, err) - capture := s.GetTestCluster().Host().CaptureMetricsHandler().StartCapture() - defer s.GetTestCluster().Host().CaptureMetricsHandler().StopCapture(capture) + s.NoError(err) + capture := env.GetTestCluster().Host().CaptureMetricsHandler().StartCapture() + defer env.GetTestCluster().Host().CaptureMetricsHandler().StopCapture(capture) var result *nexusrpc.ClientStartOperationResponse[string] @@ -251,42 +271,39 @@ func (s *NexusAPIValidationTestSuite) TestNexusStartOperation_PayloadSizeLimit() return err == nil || (!errors.As(err, &handlerErr) || handlerErr.Type != nexus.HandlerErrorTypeNotFound) }, 10*time.Second, 500*time.Millisecond) - require.Nil(t, result) + s.Nil(result) var handlerErr *nexus.HandlerError - require.ErrorAs(t, err, &handlerErr) - require.Equal(t, nexus.HandlerErrorTypeBadRequest, handlerErr.Type) - require.Equal(t, "input exceeds size limit", handlerErr.Message) + s.ErrorAs(err, &handlerErr) + s.Equal(nexus.HandlerErrorTypeBadRequest, handlerErr.Type) + s.Equal("input exceeds size limit", handlerErr.Message) } - s.Run("ByNamespaceAndTaskQueue", func() { - testFn(s.T(), getDispatchByNsAndTqURL(s.HttpAPIAddress(), s.Namespace().String(), taskQueue)) - }) - s.Run("ByEndpoint", func() { - testFn(s.T(), getDispatchByEndpointURL(s.HttpAPIAddress(), testEndpoint.Id)) - }) + s.Run("ByNamespaceAndTaskQueue", func(s *NexusAPIValidationTestSuite) { testFn(s, false) }) + s.Run("ByEndpoint", func(s *NexusAPIValidationTestSuite) { testFn(s, true) }) } func (s *NexusAPIValidationTestSuite) TestNexus_RespondNexusTaskMethods_VerifiesTaskTokenMatchesRequestNamespace() { + env := newNexusTestEnv(s.T(), false, testcore.WithDedicatedCluster()) ctx := testcore.NewContext() tt := tokenspb.NexusTask{ - NamespaceId: s.NamespaceID().String(), + NamespaceId: env.NamespaceID().String(), TaskQueue: "test", TaskId: uuid.NewString(), } ttBytes, err := tt.Marshal() s.NoError(err) - _, err = s.GetTestCluster().FrontendClient().RespondNexusTaskCompleted(ctx, &workflowservice.RespondNexusTaskCompletedRequest{ - Namespace: s.ExternalNamespace().String(), + _, err = env.FrontendClient().RespondNexusTaskCompleted(ctx, &workflowservice.RespondNexusTaskCompletedRequest{ + Namespace: env.ExternalNamespace().String(), Identity: uuid.NewString(), TaskToken: ttBytes, Response: &nexuspb.Response{}, }) s.ErrorContains(err, "Operation requested with a token from a different namespace.") - _, err = s.GetTestCluster().FrontendClient().RespondNexusTaskFailed(ctx, &workflowservice.RespondNexusTaskFailedRequest{ - Namespace: s.ExternalNamespace().String(), + _, err = env.FrontendClient().RespondNexusTaskFailed(ctx, &workflowservice.RespondNexusTaskFailedRequest{ + Namespace: env.ExternalNamespace().String(), Identity: uuid.NewString(), TaskToken: ttBytes, Error: &nexuspb.HandlerError{}, @@ -295,18 +312,19 @@ func (s *NexusAPIValidationTestSuite) TestNexus_RespondNexusTaskMethods_Verifies } func (s *NexusAPIValidationTestSuite) TestNexus_RespondNexusTaskCompleted_ValidateOperationTokenLength() { + env := newNexusTestEnv(s.T(), false, testcore.WithDedicatedCluster()) ctx := testcore.NewContext() tt := tokenspb.NexusTask{ - NamespaceId: s.NamespaceID().String(), + NamespaceId: env.NamespaceID().String(), TaskQueue: "test", TaskId: uuid.NewString(), } ttBytes, err := tt.Marshal() s.NoError(err) - _, err = s.GetTestCluster().FrontendClient().RespondNexusTaskCompleted(ctx, &workflowservice.RespondNexusTaskCompletedRequest{ - Namespace: s.Namespace().String(), + _, err = env.FrontendClient().RespondNexusTaskCompleted(ctx, &workflowservice.RespondNexusTaskCompletedRequest{ + Namespace: env.Namespace().String(), Identity: uuid.NewString(), TaskToken: ttBytes, Response: &nexuspb.Response{ @@ -327,18 +345,19 @@ func (s *NexusAPIValidationTestSuite) TestNexus_RespondNexusTaskCompleted_Valida } func (s *NexusAPIValidationTestSuite) TestNexus_RespondNexusTaskMethods_ValidateFailureDetailsJSON() { + env := newNexusTestEnv(s.T(), false, testcore.WithDedicatedCluster()) ctx := testcore.NewContext() tt := tokenspb.NexusTask{ - NamespaceId: s.NamespaceID().String(), + NamespaceId: env.NamespaceID().String(), TaskQueue: "test", TaskId: uuid.NewString(), } ttBytes, err := tt.Marshal() s.NoError(err) - _, err = s.GetTestCluster().FrontendClient().RespondNexusTaskCompleted(ctx, &workflowservice.RespondNexusTaskCompletedRequest{ - Namespace: s.Namespace().String(), + _, err = env.FrontendClient().RespondNexusTaskCompleted(ctx, &workflowservice.RespondNexusTaskCompletedRequest{ + Namespace: env.Namespace().String(), Identity: uuid.NewString(), TaskToken: ttBytes, Response: &nexuspb.Response{ @@ -360,8 +379,8 @@ func (s *NexusAPIValidationTestSuite) TestNexus_RespondNexusTaskMethods_Validate s.ErrorAs(err, &invalidArgumentErr) s.Equal("failure details must be JSON serializable", invalidArgumentErr.Message) - _, err = s.GetTestCluster().FrontendClient().RespondNexusTaskFailed(ctx, &workflowservice.RespondNexusTaskFailedRequest{ - Namespace: s.Namespace().String(), + _, err = env.FrontendClient().RespondNexusTaskFailed(ctx, &workflowservice.RespondNexusTaskFailedRequest{ + Namespace: env.Namespace().String(), Identity: uuid.NewString(), TaskToken: ttBytes, Error: &nexuspb.HandlerError{ @@ -375,12 +394,13 @@ func (s *NexusAPIValidationTestSuite) TestNexus_RespondNexusTaskMethods_Validate } func (s *NexusAPIValidationTestSuite) TestNexusStartOperation_ByEndpoint_EndpointNotFound() { - u := getDispatchByEndpointURL(s.HttpAPIAddress(), uuid.NewString()) + env := newNexusTestEnv(s.T(), false, testcore.WithDedicatedCluster()) + u := getDispatchByEndpointURL(env.HttpAPIAddress(), uuid.NewString()) client, err := nexusrpc.NewHTTPClient(nexusrpc.HTTPClientOptions{BaseURL: u, Service: "test-service"}) s.NoError(err) ctx := testcore.NewContext() - capture := s.GetTestCluster().Host().CaptureMetricsHandler().StartCapture() - defer s.GetTestCluster().Host().CaptureMetricsHandler().StopCapture(capture) + capture := env.GetTestCluster().Host().CaptureMetricsHandler().StartCapture() + defer env.GetTestCluster().Host().CaptureMetricsHandler().StopCapture(capture) _, err = nexusrpc.StartOperation(ctx, client, op, "input", nexus.StartOperationOptions{}) var handlerErr *nexus.HandlerError s.ErrorAs(err, &handlerErr) diff --git a/tests/nexus_test_base.go b/tests/nexus_test_base.go index 0cff4360d81..4c6aefc57ff 100644 --- a/tests/nexus_test_base.go +++ b/tests/nexus_test_base.go @@ -3,9 +3,11 @@ package tests import ( "context" "errors" + "testing" "github.com/google/uuid" "github.com/nexus-rpc/sdk-go/nexus" + "github.com/stretchr/testify/require" commonpb "go.temporal.io/api/common/v1" enumspb "go.temporal.io/api/enums/v1" nexuspb "go.temporal.io/api/nexus/v1" @@ -18,26 +20,33 @@ import ( "go.temporal.io/server/tests/testcore" ) -type NexusTestBaseSuite struct { - testcore.FunctionalTestBase +type NexusTestEnv struct { + *testcore.TestEnv useTemporalFailures bool } -func (s *NexusTestBaseSuite) createNexusEndpoint(name string, taskQueue string) *nexuspb.Endpoint { - resp, err := s.OperatorClient().CreateNexusEndpoint(testcore.NewContext(), &operatorservice.CreateNexusEndpointRequest{ +func newNexusTestEnv(t *testing.T, useTemporalFailures bool, opts ...testcore.TestOption) *NexusTestEnv { + return &NexusTestEnv{ + TestEnv: testcore.NewEnv(t, opts...), + useTemporalFailures: useTemporalFailures, + } +} + +func (env *NexusTestEnv) createNexusEndpoint(t *testing.T, name string, taskQueue string) *nexuspb.Endpoint { + resp, err := env.OperatorClient().CreateNexusEndpoint(testcore.NewContext(), &operatorservice.CreateNexusEndpointRequest{ Spec: &nexuspb.EndpointSpec{ Name: name, Target: &nexuspb.EndpointTarget{ Variant: &nexuspb.EndpointTarget_Worker_{ Worker: &nexuspb.EndpointTarget_Worker{ - Namespace: s.Namespace().String(), + Namespace: env.Namespace().String(), TaskQueue: taskQueue, }, }, }, }, }) - s.NoError(err) + require.NoError(t, err) return resp.Endpoint } @@ -54,21 +63,21 @@ type nexusTaskResponse struct { Links []nexus.Link } -type nexusTaskHandler func(res *workflowservice.PollNexusTaskQueueResponse) (*nexusTaskResponse, error) +type nexusTaskHandler func(t *testing.T, res *workflowservice.PollNexusTaskQueueResponse) (*nexusTaskResponse, error) -func (s *NexusTestBaseSuite) nexusTaskPoller(ctx context.Context, taskQueue string, handler nexusTaskHandler) <-chan error { - return s.versionedNexusTaskPoller(ctx, taskQueue, "", handler) +func (env *NexusTestEnv) nexusTaskPoller(ctx context.Context, t *testing.T, taskQueue string, handler nexusTaskHandler) <-chan error { + return env.versionedNexusTaskPoller(ctx, t, taskQueue, "", handler) } -func (s *NexusTestBaseSuite) versionedNexusTaskPoller(ctx context.Context, taskQueue, buildID string, handler nexusTaskHandler) <-chan error { +func (env *NexusTestEnv) versionedNexusTaskPoller(ctx context.Context, t *testing.T, taskQueue, buildID string, handler nexusTaskHandler) <-chan error { errCh := make(chan error, 1) go func() { - errCh <- s.versionedNexusTaskPollerDo(ctx, taskQueue, buildID, handler) + errCh <- env.versionedNexusTaskPollerDo(ctx, t, taskQueue, buildID, handler) }() return errCh } -func (s *NexusTestBaseSuite) versionedNexusTaskPollerDo(ctx context.Context, taskQueue, buildID string, handler nexusTaskHandler) error { +func (env *NexusTestEnv) versionedNexusTaskPollerDo(ctx context.Context, t *testing.T, taskQueue, buildID string, handler nexusTaskHandler) error { var vc *commonpb.WorkerVersionCapabilities if buildID != "" { vc = &commonpb.WorkerVersionCapabilities{ @@ -76,8 +85,8 @@ func (s *NexusTestBaseSuite) versionedNexusTaskPollerDo(ctx context.Context, tas UseVersioning: true, } } - res, err := s.GetTestCluster().FrontendClient().PollNexusTaskQueue(ctx, &workflowservice.PollNexusTaskQueueRequest{ - Namespace: s.Namespace().String(), + res, err := env.FrontendClient().PollNexusTaskQueue(ctx, &workflowservice.PollNexusTaskQueueRequest{ + Namespace: env.Namespace().String(), Identity: uuid.NewString(), TaskQueue: &taskqueuepb.TaskQueue{ Name: taskQueue, @@ -98,14 +107,14 @@ func (s *NexusTestBaseSuite) versionedNexusTaskPollerDo(ctx context.Context, tas if res.Request.GetStartOperation().GetService() != "test-service" && res.Request.GetCancelOperation().GetService() != "test-service" { return errors.New("expected service to be test-service") } - result, handlerErr := handler(res) + result, handlerErr := handler(t, res) if handlerErr != nil { var opErr *nexus.OperationError var he *nexus.HandlerError if errors.As(handlerErr, &opErr) { - return s.respondNexusTaskCompletedWithOperationError(ctx, res.TaskToken, opErr) + return env.respondNexusTaskCompletedWithOperationError(ctx, res.TaskToken, opErr) } else if errors.As(handlerErr, &he) { - return s.respondNexusTaskFailed(ctx, res.TaskToken, he) + return env.respondNexusTaskFailed(ctx, res.TaskToken, he) } return handlerErr } @@ -163,8 +172,8 @@ func (s *NexusTestBaseSuite) versionedNexusTaskPollerDo(ctx context.Context, tas panic("unreachable") // nolint:revive // all implementations of HandlerStartOperationResult must be covered here, so this should be unreachable. } } - _, err = s.GetTestCluster().FrontendClient().RespondNexusTaskCompleted(ctx, &workflowservice.RespondNexusTaskCompletedRequest{ - Namespace: s.Namespace().String(), + _, err = env.FrontendClient().RespondNexusTaskCompleted(ctx, &workflowservice.RespondNexusTaskCompletedRequest{ + Namespace: env.Namespace().String(), Identity: uuid.NewString(), TaskToken: res.TaskToken, Response: response, @@ -175,8 +184,8 @@ func (s *NexusTestBaseSuite) versionedNexusTaskPollerDo(ctx context.Context, tas return nil } -func (s *NexusTestBaseSuite) respondNexusTaskFailed(ctx context.Context, taskToken []byte, he *nexus.HandlerError) error { - if s.useTemporalFailures { +func (env *NexusTestEnv) respondNexusTaskFailed(ctx context.Context, taskToken []byte, he *nexus.HandlerError) error { + if env.useTemporalFailures { nexusFailure, err := nexusrpc.DefaultFailureConverter().ErrorToFailure(he) if err != nil { return err @@ -185,8 +194,8 @@ func (s *NexusTestBaseSuite) respondNexusTaskFailed(ctx context.Context, taskTok if err != nil { return err } - _, err = s.GetTestCluster().FrontendClient().RespondNexusTaskFailed(ctx, &workflowservice.RespondNexusTaskFailedRequest{ - Namespace: s.Namespace().String(), + _, err = env.FrontendClient().RespondNexusTaskFailed(ctx, &workflowservice.RespondNexusTaskFailedRequest{ + Namespace: env.Namespace().String(), Identity: uuid.NewString(), TaskToken: taskToken, Failure: temporalFailure, @@ -219,8 +228,8 @@ func (s *NexusTestBaseSuite) respondNexusTaskFailed(ctx context.Context, taskTok protoError.RetryBehavior = enumspb.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_NON_RETRYABLE default: } - _, err := s.GetTestCluster().FrontendClient().RespondNexusTaskFailed(ctx, &workflowservice.RespondNexusTaskFailedRequest{ - Namespace: s.Namespace().String(), + _, err := env.FrontendClient().RespondNexusTaskFailed(ctx, &workflowservice.RespondNexusTaskFailedRequest{ + Namespace: env.Namespace().String(), Identity: uuid.NewString(), TaskToken: taskToken, Error: protoError, @@ -231,8 +240,8 @@ func (s *NexusTestBaseSuite) respondNexusTaskFailed(ctx context.Context, taskTok return nil } -func (s *NexusTestBaseSuite) respondNexusTaskCompletedWithOperationError(ctx context.Context, taskToken []byte, opErr *nexus.OperationError) error { - if s.useTemporalFailures { +func (env *NexusTestEnv) respondNexusTaskCompletedWithOperationError(ctx context.Context, taskToken []byte, opErr *nexus.OperationError) error { + if env.useTemporalFailures { nexusFailure, err := nexusrpc.DefaultFailureConverter().ErrorToFailure(opErr) if err != nil { return err @@ -250,8 +259,8 @@ func (s *NexusTestBaseSuite) respondNexusTaskCompletedWithOperationError(ctx con }, }, } - _, err = s.GetTestCluster().FrontendClient().RespondNexusTaskCompleted(ctx, &workflowservice.RespondNexusTaskCompletedRequest{ - Namespace: s.Namespace().String(), + _, err = env.FrontendClient().RespondNexusTaskCompleted(ctx, &workflowservice.RespondNexusTaskCompletedRequest{ + Namespace: env.Namespace().String(), Identity: uuid.NewString(), TaskToken: taskToken, Response: response, @@ -284,8 +293,8 @@ func (s *NexusTestBaseSuite) respondNexusTaskCompletedWithOperationError(ctx con }, }, } - _, err := s.GetTestCluster().FrontendClient().RespondNexusTaskCompleted(ctx, &workflowservice.RespondNexusTaskCompletedRequest{ - Namespace: s.Namespace().String(), + _, err := env.FrontendClient().RespondNexusTaskCompleted(ctx, &workflowservice.RespondNexusTaskCompletedRequest{ + Namespace: env.Namespace().String(), Identity: uuid.NewString(), TaskToken: taskToken, Response: response, diff --git a/tests/nexus_workflow_test.go b/tests/nexus_workflow_test.go index 4acf8dd67fc..4607e9e2285 100644 --- a/tests/nexus_workflow_test.go +++ b/tests/nexus_workflow_test.go @@ -14,7 +14,6 @@ import ( "github.com/nexus-rpc/sdk-go/nexus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" commandpb "go.temporal.io/api/command/v1" commonpb "go.temporal.io/api/common/v1" enumspb "go.temporal.io/api/enums/v1" @@ -40,6 +39,7 @@ import ( commonnexus "go.temporal.io/server/common/nexus" "go.temporal.io/server/common/nexus/nexusrpc" "go.temporal.io/server/common/nexus/nexustest" + "go.temporal.io/server/common/testing/parallelsuite" "go.temporal.io/server/common/testing/protorequire" "go.temporal.io/server/components/nexusoperations" "go.temporal.io/server/service/frontend/configs" @@ -48,19 +48,15 @@ import ( ) type NexusWorkflowTestSuite struct { - NexusTestBaseSuite + parallelsuite.Suite[*NexusWorkflowTestSuite] } func TestNexusWorkflowTestSuite(t *testing.T) { - t.Parallel() - suite.Run(t, &NexusWorkflowTestSuite{ - NexusTestBaseSuite: NexusTestBaseSuite{ - useTemporalFailures: true, - }, - }) + parallelsuite.Run(t, &NexusWorkflowTestSuite{}) } func (s *NexusWorkflowTestSuite) TestNexusOperationCancelation() { + env := newNexusTestEnv(s.T(), true, testcore.WithDedicatedCluster()) ctx := testcore.NewContext() taskQueue := testcore.RandomizeStr(s.T().Name()) endpointName := testcore.RandomizedNexusEndpoint(s.T().Name()) @@ -85,7 +81,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationCancelation() { listenAddr := nexustest.AllocListenAddress() nexustest.NewNexusServer(s.T(), listenAddr, h) - _, err := s.OperatorClient().CreateNexusEndpoint(ctx, &operatorservice.CreateNexusEndpointRequest{ + _, err := env.OperatorClient().CreateNexusEndpoint(ctx, &operatorservice.CreateNexusEndpointRequest{ Spec: &nexuspb.EndpointSpec{ Name: endpointName, Target: &nexuspb.EndpointTarget{ @@ -99,15 +95,15 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationCancelation() { }) s.NoError(err) - run, err := s.SdkClient().ExecuteWorkflow(ctx, client.StartWorkflowOptions{ + run, err := env.SdkClient().ExecuteWorkflow(ctx, client.StartWorkflowOptions{ TaskQueue: taskQueue, WorkflowTaskTimeout: time.Second, }, "workflow") s.NoError(err) s.EventuallyWithT(func(t *assert.CollectT) { - pollResp, err := s.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ - Namespace: s.Namespace().String(), + pollResp, err := env.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ + Namespace: env.Namespace().String(), TaskQueue: &taskqueuepb.TaskQueue{ Name: taskQueue, Kind: enumspb.TASK_QUEUE_KIND_NORMAL, @@ -115,7 +111,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationCancelation() { Identity: "test", }) require.NoError(t, err) - _, err = s.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ + _, err = env.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ Identity: "test", TaskToken: pollResp.TaskToken, Commands: []*commandpb.Command{ @@ -136,8 +132,8 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationCancelation() { }, time.Second*20, time.Millisecond*200) // Poll and wait for the "started" event to be recorded. - pollResp, err := s.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ - Namespace: s.Namespace().String(), + pollResp, err := env.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ + Namespace: env.Namespace().String(), TaskQueue: &taskqueuepb.TaskQueue{ Name: taskQueue, Kind: enumspb.TASK_QUEUE_KIND_NORMAL, @@ -157,7 +153,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationCancelation() { }) s.Positive(scheduledEventIdx) - _, err = s.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ + _, err = env.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ Identity: "test", TaskToken: pollResp.TaskToken, Commands: []*commandpb.Command{ @@ -174,8 +170,8 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationCancelation() { s.NoError(err) // Poll and verify first cancel request failed and allowed workflow to make progress. - pollResp, err = s.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ - Namespace: s.Namespace().String(), + pollResp, err = env.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ + Namespace: env.Namespace().String(), TaskQueue: &taskqueuepb.TaskQueue{ Name: taskQueue, Kind: enumspb.TASK_QUEUE_KIND_NORMAL, @@ -189,7 +185,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationCancelation() { s.Positive(cancelFailedIdx) // Start new operation to successfully cancel. - _, err = s.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ + _, err = env.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ Identity: "test", TaskToken: pollResp.TaskToken, Commands: []*commandpb.Command{ @@ -208,8 +204,8 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationCancelation() { }) s.NoError(err) // Poll and wait for the "started" event to be recorded. - pollResp, err = s.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ - Namespace: s.Namespace().String(), + pollResp, err = env.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ + Namespace: env.Namespace().String(), TaskQueue: &taskqueuepb.TaskQueue{ Name: taskQueue, Kind: enumspb.TASK_QUEUE_KIND_NORMAL, @@ -226,7 +222,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationCancelation() { } } s.Positive(secondScheduledEventID) - _, err = s.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ + _, err = env.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ Identity: "test", TaskToken: pollResp.TaskToken, Commands: []*commandpb.Command{ @@ -244,7 +240,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationCancelation() { // Poll and wait for the cancelation request to go through. s.EventuallyWithT(func(t *assert.CollectT) { - desc, err := s.SdkClient().DescribeWorkflowExecution(ctx, run.GetID(), run.GetRunID()) + desc, err := env.SdkClient().DescribeWorkflowExecution(ctx, run.GetID(), run.GetRunID()) require.NoError(t, err) require.Len(t, desc.PendingNexusOperations, 2) op1 := desc.PendingNexusOperations[0] @@ -263,10 +259,10 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationCancelation() { require.Equal(t, enumspb.NEXUS_OPERATION_CANCELLATION_STATE_SUCCEEDED, op2.CancellationInfo.State) }, time.Second*5, time.Millisecond*30) - err = s.SdkClient().TerminateWorkflow(ctx, run.GetID(), run.GetRunID(), "test") + err = env.SdkClient().TerminateWorkflow(ctx, run.GetID(), run.GetRunID(), "test") s.NoError(err) - hist := s.GetHistory(s.Namespace().String(), &commonpb.WorkflowExecution{ + hist := env.GetHistory(env.Namespace().String(), &commonpb.WorkflowExecution{ WorkflowId: run.GetID(), RunId: run.GetRunID(), }) @@ -275,6 +271,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationCancelation() { } func (s *NexusWorkflowTestSuite) TestNexusOperationSyncCompletion() { + env := newNexusTestEnv(s.T(), true, testcore.WithDedicatedCluster()) ctx := testcore.NewContext() taskQueue := testcore.RandomizeStr(s.T().Name()) endpointName := testcore.RandomizedNexusEndpoint(s.T().Name()) @@ -300,7 +297,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationSyncCompletion() { listenAddr := nexustest.AllocListenAddress() nexustest.NewNexusServer(s.T(), listenAddr, h) - _, err := s.OperatorClient().CreateNexusEndpoint(ctx, &operatorservice.CreateNexusEndpointRequest{ + _, err := env.OperatorClient().CreateNexusEndpoint(ctx, &operatorservice.CreateNexusEndpointRequest{ Spec: &nexuspb.EndpointSpec{ Name: endpointName, Target: &nexuspb.EndpointTarget{ @@ -314,14 +311,14 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationSyncCompletion() { }) s.NoError(err) - run, err := s.SdkClient().ExecuteWorkflow(ctx, client.StartWorkflowOptions{ + run, err := env.SdkClient().ExecuteWorkflow(ctx, client.StartWorkflowOptions{ TaskQueue: taskQueue, }, "workflow") s.NoError(err) s.EventuallyWithT(func(t *assert.CollectT) { - pollResp, err := s.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ - Namespace: s.Namespace().String(), + pollResp, err := env.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ + Namespace: env.Namespace().String(), TaskQueue: &taskqueuepb.TaskQueue{ Name: taskQueue, Kind: enumspb.TASK_QUEUE_KIND_NORMAL, @@ -329,7 +326,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationSyncCompletion() { Identity: "test", }) require.NoError(t, err) - _, err = s.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ + _, err = env.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ Identity: "test", TaskToken: pollResp.TaskToken, Commands: []*commandpb.Command{ @@ -349,8 +346,8 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationSyncCompletion() { require.NoError(t, err) }, time.Second*20, time.Millisecond*200) - pollResp, err := s.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ - Namespace: s.Namespace().String(), + pollResp, err := env.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ + Namespace: env.Namespace().String(), TaskQueue: &taskqueuepb.TaskQueue{ Name: taskQueue, Kind: enumspb.TASK_QUEUE_KIND_NORMAL, @@ -365,7 +362,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationSyncCompletion() { s.Len(pollResp.History.Events[completedEventIdx].GetLinks(), 1) protorequire.ProtoEqual(s.T(), handlerLink, pollResp.History.Events[completedEventIdx].GetLinks()[0].GetWorkflowEvent()) - _, err = s.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ + _, err = env.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ Identity: "test", TaskToken: pollResp.TaskToken, Commands: []*commandpb.Command{ @@ -391,8 +388,8 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationSyncCompletion() { // Use this test case to verify that the state machine is actually deleted, the workflowservice // DescribeWorkflowExecution API filters out operations in terminal state in case they complete in a server version // without state machine deletion enabled, hence the use of the adminservice API here. - desc, err := s.AdminClient().DescribeMutableState(ctx, &adminservice.DescribeMutableStateRequest{ - Namespace: s.Namespace().String(), + desc, err := env.AdminClient().DescribeMutableState(ctx, &adminservice.DescribeMutableStateRequest{ + Namespace: env.Namespace().String(), Execution: &commonpb.WorkflowExecution{ WorkflowId: run.GetID(), }, @@ -403,6 +400,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationSyncCompletion() { } func (s *NexusWorkflowTestSuite) TestNexusOperationSyncCompletion_LargePayload() { + env := newNexusTestEnv(s.T(), true, testcore.WithDedicatedCluster()) ctx := testcore.NewContext() taskQueue := testcore.RandomizeStr(s.T().Name()) endpointName := testcore.RandomizedNexusEndpoint(s.T().Name()) @@ -417,7 +415,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationSyncCompletion_LargePayload() listenAddr := nexustest.AllocListenAddress() nexustest.NewNexusServer(s.T(), listenAddr, h) - _, err := s.OperatorClient().CreateNexusEndpoint(ctx, &operatorservice.CreateNexusEndpointRequest{ + _, err := env.OperatorClient().CreateNexusEndpoint(ctx, &operatorservice.CreateNexusEndpointRequest{ Spec: &nexuspb.EndpointSpec{ Name: endpointName, Target: &nexuspb.EndpointTarget{ @@ -431,14 +429,14 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationSyncCompletion_LargePayload() }) s.NoError(err) - run, err := s.SdkClient().ExecuteWorkflow(ctx, client.StartWorkflowOptions{ + run, err := env.SdkClient().ExecuteWorkflow(ctx, client.StartWorkflowOptions{ TaskQueue: taskQueue, }, "workflow") s.NoError(err) s.EventuallyWithT(func(t *assert.CollectT) { - pollResp, err := s.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ - Namespace: s.Namespace().String(), + pollResp, err := env.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ + Namespace: env.Namespace().String(), TaskQueue: &taskqueuepb.TaskQueue{ Name: taskQueue, Kind: enumspb.TASK_QUEUE_KIND_NORMAL, @@ -446,7 +444,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationSyncCompletion_LargePayload() Identity: "test", }) require.NoError(t, err) - _, err = s.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ + _, err = env.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ Identity: "test", TaskToken: pollResp.TaskToken, Commands: []*commandpb.Command{ @@ -466,8 +464,8 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationSyncCompletion_LargePayload() require.NoError(t, err) }, time.Second*20, time.Millisecond*200) - pollResp, err := s.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ - Namespace: s.Namespace().String(), + pollResp, err := env.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ + Namespace: env.Namespace().String(), TaskQueue: &taskqueuepb.TaskQueue{ Name: taskQueue, Kind: enumspb.TASK_QUEUE_KIND_NORMAL, @@ -480,7 +478,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationSyncCompletion_LargePayload() }) s.Positive(failedEventIdx) - _, err = s.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ + _, err = env.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ Identity: "test", TaskToken: pollResp.TaskToken, Commands: []*commandpb.Command{ @@ -506,14 +504,15 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationSyncCompletion_LargePayload() } func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletion() { + env := newNexusTestEnv(s.T(), true, testcore.WithDedicatedCluster()) ctx := testcore.NewContext() taskQueue := testcore.RandomizeStr(s.T().Name()) endpointName := testcore.RandomizedNexusEndpoint(s.T().Name()) - testClusterInfo, err := s.FrontendClient().GetClusterInfo(ctx, &workflowservice.GetClusterInfoRequest{}) + testClusterInfo, err := env.FrontendClient().GetClusterInfo(ctx, &workflowservice.GetClusterInfoRequest{}) s.NoError(err) - run, err := s.SdkClient().ExecuteWorkflow(ctx, client.StartWorkflowOptions{ + run, err := env.SdkClient().ExecuteWorkflow(ctx, client.StartWorkflowOptions{ TaskQueue: taskQueue, }, "workflow") s.NoError(err) @@ -556,7 +555,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletion() { } s.NotNil(links[0].GetWorkflowEvent()) protorequire.ProtoEqual(s.T(), &commonpb.Link_WorkflowEvent{ - Namespace: s.Namespace().String(), + Namespace: env.Namespace().String(), WorkflowId: run.GetID(), RunId: run.GetRunID(), Reference: &commonpb.Link_WorkflowEvent_EventRef{ @@ -584,7 +583,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletion() { listenAddr := nexustest.AllocListenAddress() nexustest.NewNexusServer(s.T(), listenAddr, h) - _, err = s.OperatorClient().CreateNexusEndpoint(ctx, &operatorservice.CreateNexusEndpointRequest{ + _, err = env.OperatorClient().CreateNexusEndpoint(ctx, &operatorservice.CreateNexusEndpointRequest{ Spec: &nexuspb.EndpointSpec{ Name: endpointName, Target: &nexuspb.EndpointTarget{ @@ -598,8 +597,8 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletion() { }) s.NoError(err) - pollResp, err := s.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ - Namespace: s.Namespace().String(), + pollResp, err := env.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ + Namespace: env.Namespace().String(), TaskQueue: &taskqueuepb.TaskQueue{ Name: taskQueue, Kind: enumspb.TASK_QUEUE_KIND_NORMAL, @@ -607,7 +606,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletion() { Identity: "test", }) s.NoError(err) - _, err = s.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ + _, err = env.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ Identity: "test", TaskToken: pollResp.TaskToken, Commands: []*commandpb.Command{ @@ -627,8 +626,8 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletion() { s.NoError(err) // Poll and verify that the "started" event was recorded. - pollResp, err = s.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ - Namespace: s.Namespace().String(), + pollResp, err = env.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ + Namespace: env.Namespace().String(), TaskQueue: &taskqueuepb.TaskQueue{ Name: taskQueue, Kind: enumspb.TASK_QUEUE_KIND_NORMAL, @@ -636,7 +635,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletion() { Identity: "test", }) s.NoError(err) - _, err = s.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ + _, err = env.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ Identity: "test", TaskToken: pollResp.TaskToken, }) @@ -663,28 +662,28 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletion() { Header: nexus.Header{commonnexus.CallbackTokenHeader: callbackToken}, } s.NoError(err) - snap, err := s.sendNexusCompletionRequest(ctx, publicCallbackURL, largeCompletion) + snap, err := s.sendNexusCompletionRequest(ctx, env, publicCallbackURL, largeCompletion) var handlerErr *nexus.HandlerError s.ErrorAs(err, &handlerErr) s.Equal(nexus.HandlerErrorTypeBadRequest, handlerErr.Type) s.Len(snap["nexus_completion_requests"], 1) - s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": s.Namespace().String(), "outcome": "error_bad_request"}) + s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": env.Namespace().String(), "outcome": "error_bad_request"}) invalidNamespace := testcore.RandomizeStr("ns") - _, err = s.FrontendClient().RegisterNamespace(ctx, &workflowservice.RegisterNamespaceRequest{ + _, err = env.FrontendClient().RegisterNamespace(ctx, &workflowservice.RegisterNamespaceRequest{ Namespace: invalidNamespace, WorkflowExecutionRetentionPeriod: durationpb.New(time.Hour * 24), }) s.NoError(err) // Send an invalid completion request and verify that we get an error that the namespace in the URL doesn't match the namespace in the token. - invalidCallbackURL := "http://" + s.HttpAPIAddress() + "/" + commonnexus.RouteCompletionCallback.Path(invalidNamespace) + invalidCallbackURL := "http://" + env.HttpAPIAddress() + "/" + commonnexus.RouteCompletionCallback.Path(invalidNamespace) completion := nexusrpc.CompleteOperationOptions{ Result: testcore.MustToPayload(s.T(), "result"), Header: nexus.Header{commonnexus.CallbackTokenHeader: callbackToken}, } - _, err = s.sendNexusCompletionRequest(ctx, invalidCallbackURL, completion) + _, err = s.sendNexusCompletionRequest(ctx, env, invalidCallbackURL, completion) // Verify we get the correct error response s.ErrorAs(err, &handlerErr) s.Equal(nexus.HandlerErrorTypeBadRequest, handlerErr.Type) @@ -704,11 +703,11 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletion() { s.NoError(err) completion.Header = nexus.Header{commonnexus.CallbackTokenHeader: callbackToken} - snap, err = s.sendNexusCompletionRequest(ctx, publicCallbackURL, completion) + snap, err = s.sendNexusCompletionRequest(ctx, env, publicCallbackURL, completion) s.ErrorAs(err, &handlerErr) s.Equal(nexus.HandlerErrorTypeNotFound, handlerErr.Type) s.Len(snap["nexus_completion_requests"], 1) - s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": s.Namespace().String(), "outcome": "error_not_found"}) + s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": env.Namespace().String(), "outcome": "error_not_found"}) // Request fails if the state machine reference is stale. staleToken := common.CloneProto(completionToken) @@ -717,20 +716,20 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletion() { s.NoError(err) completion.Header = nexus.Header{commonnexus.CallbackTokenHeader: callbackToken} - snap, err = s.sendNexusCompletionRequest(ctx, publicCallbackURL, completion) + snap, err = s.sendNexusCompletionRequest(ctx, env, publicCallbackURL, completion) s.ErrorAs(err, &handlerErr) s.Equal(nexus.HandlerErrorTypeNotFound, handlerErr.Type) s.Len(snap["nexus_completion_requests"], 1) - s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": s.Namespace().String(), "outcome": "error_not_found"}) + s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": env.Namespace().String(), "outcome": "error_not_found"}) callbackToken, err = gen.Tokenize(completionToken) s.NoError(err) completion.Header = nexus.Header{commonnexus.CallbackTokenHeader: callbackToken} - snap, err = s.sendNexusCompletionRequest(ctx, publicCallbackURL, completion) + snap, err = s.sendNexusCompletionRequest(ctx, env, publicCallbackURL, completion) s.NoError(err) s.Len(snap["nexus_completion_requests"], 1) - s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": s.Namespace().String(), "outcome": "success"}) + s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": env.Namespace().String(), "outcome": "success"}) // Ensure that CompleteOperation request is tracked as part of normal service telemetry metrics idx := slices.IndexFunc(snap["service_requests"], func(m *metricstest.CapturedRecording) bool { opTag, ok := m.Tags["operation"] @@ -739,15 +738,15 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletion() { s.Greater(idx, -1) // Resend the request and verify we get a not found error since the operation has already completed. - snap, err = s.sendNexusCompletionRequest(ctx, publicCallbackURL, completion) + snap, err = s.sendNexusCompletionRequest(ctx, env, publicCallbackURL, completion) s.ErrorAs(err, &handlerErr) s.Equal(nexus.HandlerErrorTypeNotFound, handlerErr.Type) s.Len(snap["nexus_completion_requests"], 1) - s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": s.Namespace().String(), "outcome": "error_not_found"}) + s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": env.Namespace().String(), "outcome": "error_not_found"}) // Poll again and verify the completion is recorded and triggers workflow progress. - pollResp, err = s.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ - Namespace: s.Namespace().String(), + pollResp, err = env.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ + Namespace: env.Namespace().String(), TaskQueue: &taskqueuepb.TaskQueue{ Name: taskQueue, Kind: enumspb.TASK_QUEUE_KIND_NORMAL, @@ -760,7 +759,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletion() { }) s.Positive(completedEventIdx) - _, err = s.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ + _, err = env.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ Identity: "test", TaskToken: pollResp.TaskToken, Commands: []*commandpb.Command{ @@ -784,8 +783,8 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletion() { s.Equal("result", result) // Reset the workflow and check that the completion event has been reapplied. - resp, err := s.FrontendClient().ResetWorkflowExecution(ctx, &workflowservice.ResetWorkflowExecutionRequest{ - Namespace: s.Namespace().String(), + resp, err := env.FrontendClient().ResetWorkflowExecution(ctx, &workflowservice.ResetWorkflowExecutionRequest{ + Namespace: env.Namespace().String(), WorkflowExecution: pollResp.WorkflowExecution, Reason: "test", RequestId: uuid.NewString(), @@ -793,7 +792,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletion() { }) s.NoError(err) - hist := s.SdkClient().GetWorkflowHistory(ctx, run.GetID(), resp.RunId, false, enumspb.HISTORY_EVENT_FILTER_TYPE_ALL_EVENT) + hist := env.SdkClient().GetWorkflowHistory(ctx, run.GetID(), resp.RunId, false, enumspb.HISTORY_EVENT_FILTER_TYPE_ALL_EVENT) seenCompletedEvent := false for hist.HasNext() { @@ -808,8 +807,8 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletion() { // Reset the workflow again to the same point with enumspb.RESET_REAPPLY_EXCLUDE_TYPE_NEXUS option // and verify that the completion event has been excluded. - resp, err = s.FrontendClient().ResetWorkflowExecution(ctx, &workflowservice.ResetWorkflowExecutionRequest{ - Namespace: s.Namespace().String(), + resp, err = env.FrontendClient().ResetWorkflowExecution(ctx, &workflowservice.ResetWorkflowExecutionRequest{ + Namespace: env.Namespace().String(), WorkflowExecution: pollResp.WorkflowExecution, Reason: "test", RequestId: uuid.NewString(), @@ -818,7 +817,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletion() { }) s.NoError(err) - hist = s.SdkClient().GetWorkflowHistory(ctx, run.GetID(), resp.RunId, false, enumspb.HISTORY_EVENT_FILTER_TYPE_ALL_EVENT) + hist = env.SdkClient().GetWorkflowHistory(ctx, run.GetID(), resp.RunId, false, enumspb.HISTORY_EVENT_FILTER_TYPE_ALL_EVENT) seenCompletedEvent = false for hist.HasNext() { @@ -833,6 +832,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletion() { } func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionBeforeStart() { + env := newNexusTestEnv(s.T(), true, testcore.WithDedicatedCluster()) ctx := testcore.NewContext() taskQueues := []string{testcore.RandomizeStr(s.T().Name()), testcore.RandomizeStr(s.T().Name())} wfRuns := []client.WorkflowRun{} @@ -842,7 +842,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionBeforeStart() completionWFID := testcore.RandomizeStr(s.T().Name()) completionWFTaskQueue := testcore.RandomizeStr(s.T().Name()) completionWFStartReq := &workflowservice.StartWorkflowExecutionRequest{ - Namespace: s.Namespace().String(), + Namespace: env.Namespace().String(), WorkflowId: completionWFID, WorkflowType: &commonpb.WorkflowType{Name: completionWFType}, TaskQueue: &taskqueuepb.TaskQueue{Name: completionWFTaskQueue, Kind: enumspb.TASK_QUEUE_KIND_NORMAL}, @@ -863,13 +863,13 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionBeforeStart() // The second workflow will have its callback attached to the running workflow. for _, tq := range taskQueues { endpointName := testcore.RandomizedNexusEndpoint(s.T().Name()) - _, err := s.OperatorClient().CreateNexusEndpoint(ctx, &operatorservice.CreateNexusEndpointRequest{ + _, err := env.OperatorClient().CreateNexusEndpoint(ctx, &operatorservice.CreateNexusEndpointRequest{ Spec: &nexuspb.EndpointSpec{ Name: endpointName, Target: &nexuspb.EndpointTarget{ Variant: &nexuspb.EndpointTarget_Worker_{ Worker: &nexuspb.EndpointTarget_Worker{ - Namespace: s.Namespace().String(), + Namespace: env.Namespace().String(), TaskQueue: tq, }, }, @@ -878,15 +878,15 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionBeforeStart() }) s.NoError(err) - run, err := s.SdkClient().ExecuteWorkflow(ctx, client.StartWorkflowOptions{ + run, err := env.SdkClient().ExecuteWorkflow(ctx, client.StartWorkflowOptions{ TaskQueue: tq, }, "workflow") s.NoError(err) wfRuns = append(wfRuns, run) // Poll workflow task, and schedule Nexus operation. - pollResp, err := s.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ - Namespace: s.Namespace().String(), + pollResp, err := env.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ + Namespace: env.Namespace().String(), TaskQueue: &taskqueuepb.TaskQueue{ Name: tq, Kind: enumspb.TASK_QUEUE_KIND_NORMAL, @@ -894,7 +894,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionBeforeStart() Identity: "test", }) s.NoError(err) - _, err = s.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ + _, err = env.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ Identity: "test", TaskToken: pollResp.TaskToken, Commands: []*commandpb.Command{ @@ -914,8 +914,8 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionBeforeStart() s.NoError(err) // Poll Nexus task - nexusTask, err := s.FrontendClient().PollNexusTaskQueue(ctx, &workflowservice.PollNexusTaskQueueRequest{ - Namespace: s.Namespace().String(), + nexusTask, err := env.FrontendClient().PollNexusTaskQueue(ctx, &workflowservice.PollNexusTaskQueueRequest{ + Namespace: env.Namespace().String(), Identity: uuid.NewString(), TaskQueue: &taskqueuepb.TaskQueue{ Name: tq, @@ -944,7 +944,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionBeforeStart() completionWFStartRequestIDs = append(completionWFStartRequestIDs, completionWFStartReq.RequestId) // Start workflow (first request) or attach callback (second request) - completionRun, err := s.FrontendClient().StartWorkflowExecution(ctx, completionWFStartReq) + completionRun, err := env.FrontendClient().StartWorkflowExecution(ctx, completionWFStartReq) s.NoError(err) completionWfRunIDs = append(completionWfRunIDs, completionRun.RunId) } @@ -958,8 +958,8 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionBeforeStart() s.Equal(completionWfRunIDs[0], completionWfRunIDs[1]) // Complete workflow containing callback - pollResp, err := s.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ - Namespace: s.Namespace().String(), + pollResp, err := env.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ + Namespace: env.Namespace().String(), TaskQueue: &taskqueuepb.TaskQueue{ Name: completionWFTaskQueue, Kind: enumspb.TASK_QUEUE_KIND_NORMAL, @@ -967,7 +967,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionBeforeStart() Identity: "test", }) s.NoError(err) - _, err = s.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ + _, err = env.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ Identity: "test", TaskToken: pollResp.TaskToken, Commands: []*commandpb.Command{ @@ -989,7 +989,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionBeforeStart() expectedLinks := []*commonpb.Link_WorkflowEvent{ { - Namespace: s.Namespace().String(), + Namespace: env.Namespace().String(), WorkflowId: completionWFID, RunId: completionWfRunIDs[0], Reference: &commonpb.Link_WorkflowEvent_EventRef{ @@ -1000,7 +1000,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionBeforeStart() }, }, { - Namespace: s.Namespace().String(), + Namespace: env.Namespace().String(), WorkflowId: completionWFID, RunId: completionWfRunIDs[1], Reference: &commonpb.Link_WorkflowEvent_RequestIdRef{ @@ -1014,8 +1014,8 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionBeforeStart() for i, tq := range taskQueues { // Poll and verify the fabricated start event and completion event are recorded and triggers workflow progress. - pollResp, err = s.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ - Namespace: s.Namespace().String(), + pollResp, err = env.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ + Namespace: env.Namespace().String(), TaskQueue: &taskqueuepb.TaskQueue{ Name: tq, Kind: enumspb.TASK_QUEUE_KIND_NORMAL, @@ -1037,8 +1037,8 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionBeforeStart() s.Positive(completedEventIdx) // Complete start request to verify response is ignored. - _, err = s.FrontendClient().RespondNexusTaskCompleted(ctx, &workflowservice.RespondNexusTaskCompletedRequest{ - Namespace: s.Namespace().String(), + _, err = env.FrontendClient().RespondNexusTaskCompleted(ctx, &workflowservice.RespondNexusTaskCompletedRequest{ + Namespace: env.Namespace().String(), Identity: uuid.NewString(), TaskToken: nexusTasks[i].TaskToken, Response: &nexuspb.Response{ @@ -1056,7 +1056,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionBeforeStart() s.NoErrorf(err, "Duplicate start response should be ignored.") // Complete caller workflow - _, err = s.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ + _, err = env.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ Identity: "test", TaskToken: pollResp.TaskToken, Commands: []*commandpb.Command{ @@ -1083,6 +1083,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionBeforeStart() } func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncFailure() { + env := newNexusTestEnv(s.T(), true, testcore.WithDedicatedCluster()) ctx := testcore.NewContext() taskQueue := testcore.RandomizeStr(s.T().Name()) endpointName := testcore.RandomizedNexusEndpoint(s.T().Name()) @@ -1099,7 +1100,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncFailure() { listenAddr := nexustest.AllocListenAddress() nexustest.NewNexusServer(s.T(), listenAddr, h) - _, err := s.OperatorClient().CreateNexusEndpoint(ctx, &operatorservice.CreateNexusEndpointRequest{ + _, err := env.OperatorClient().CreateNexusEndpoint(ctx, &operatorservice.CreateNexusEndpointRequest{ Spec: &nexuspb.EndpointSpec{ Name: endpointName, Target: &nexuspb.EndpointTarget{ @@ -1113,14 +1114,14 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncFailure() { }) s.NoError(err) - run, err := s.SdkClient().ExecuteWorkflow(ctx, client.StartWorkflowOptions{ + run, err := env.SdkClient().ExecuteWorkflow(ctx, client.StartWorkflowOptions{ TaskQueue: taskQueue, }, "workflow") s.NoError(err) s.EventuallyWithT(func(t *assert.CollectT) { - pollResp, err := s.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ - Namespace: s.Namespace().String(), + pollResp, err := env.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ + Namespace: env.Namespace().String(), TaskQueue: &taskqueuepb.TaskQueue{ Name: taskQueue, Kind: enumspb.TASK_QUEUE_KIND_NORMAL, @@ -1128,7 +1129,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncFailure() { Identity: "test", }) require.NoError(t, err) - _, err = s.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ + _, err = env.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ Identity: "test", TaskToken: pollResp.TaskToken, Commands: []*commandpb.Command{ @@ -1149,8 +1150,8 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncFailure() { }, time.Second*20, time.Millisecond*200) // Poll and verify that the "started" event was recorded. - pollResp, err := s.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ - Namespace: s.Namespace().String(), + pollResp, err := env.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ + Namespace: env.Namespace().String(), TaskQueue: &taskqueuepb.TaskQueue{ Name: taskQueue, Kind: enumspb.TASK_QUEUE_KIND_NORMAL, @@ -1158,7 +1159,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncFailure() { Identity: "test", }) s.NoError(err) - _, err = s.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ + _, err = env.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ Identity: "test", TaskToken: pollResp.TaskToken, }) @@ -1174,14 +1175,14 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncFailure() { Error: nexus.NewOperationFailedErrorf("test operation failed"), Header: nexus.Header{commonnexus.CallbackTokenHeader: callbackToken}, } - snap, err := s.sendNexusCompletionRequest(ctx, publicCallbackURL, completion) + snap, err := s.sendNexusCompletionRequest(ctx, env, publicCallbackURL, completion) s.NoError(err) s.Len(snap["nexus_completion_requests"], 1) - s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": s.Namespace().String(), "outcome": "success"}) + s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": env.Namespace().String(), "outcome": "success"}) // Poll again and verify the completion is recorded and triggers workflow progress. - pollResp, err = s.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ - Namespace: s.Namespace().String(), + pollResp, err = env.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ + Namespace: env.Namespace().String(), TaskQueue: &taskqueuepb.TaskQueue{ Name: taskQueue, Kind: enumspb.TASK_QUEUE_KIND_NORMAL, @@ -1194,7 +1195,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncFailure() { }) s.Positive(completedEventIdx) - _, err = s.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ + _, err = env.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ Identity: "test", TaskToken: pollResp.TaskToken, Commands: []*commandpb.Command{ @@ -1222,45 +1223,48 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncFailure() { func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionErrors() { ctx := testcore.NewContext() - s.Run("NamespaceNotFound", func() { + s.Run("NamespaceNotFound", func(s *NexusWorkflowTestSuite) { + env := newNexusTestEnv(s.T(), true, testcore.WithDedicatedCluster()) // Generate a token with a non-existent namespace ID tokenWithBadNamespace, err := s.generateValidCallbackToken("namespace-doesnt-exist-id", testcore.RandomizeStr("workflow"), uuid.NewString()) s.NoError(err) - publicCallbackURL := "http://" + s.HttpAPIAddress() + "/" + commonnexus.RouteCompletionCallback.Path("namespace-doesnt-exist") + publicCallbackURL := "http://" + env.HttpAPIAddress() + "/" + commonnexus.RouteCompletionCallback.Path("namespace-doesnt-exist") completion := nexusrpc.CompleteOperationOptions{ Result: testcore.MustToPayload(s.T(), "result"), Header: nexus.Header{commonnexus.CallbackTokenHeader: tokenWithBadNamespace}, } - snap, err := s.sendNexusCompletionRequest(ctx, publicCallbackURL, completion) + snap, err := s.sendNexusCompletionRequest(ctx, env, publicCallbackURL, completion) var handlerErr *nexus.HandlerError s.ErrorAs(err, &handlerErr) s.Equal(nexus.HandlerErrorTypeNotFound, handlerErr.Type) s.Len(snap["nexus_completion_request_preprocess_errors"], 1) }) - s.Run("NamespaceNotFoundNoIdentifier", func() { + s.Run("NamespaceNotFoundNoIdentifier", func(s *NexusWorkflowTestSuite) { + env := newNexusTestEnv(s.T(), true, testcore.WithDedicatedCluster()) // Generate a token with a non-existent namespace ID tokenWithBadNamespace, err := s.generateValidCallbackToken("namespace-doesnt-exist-id", testcore.RandomizeStr("workflow"), uuid.NewString()) s.NoError(err) - publicCallbackURL := "http://" + s.HttpAPIAddress() + commonnexus.PathCompletionCallbackNoIdentifier + publicCallbackURL := "http://" + env.HttpAPIAddress() + commonnexus.PathCompletionCallbackNoIdentifier completion := nexusrpc.CompleteOperationOptions{ Result: testcore.MustToPayload(s.T(), "result"), Header: nexus.Header{commonnexus.CallbackTokenHeader: tokenWithBadNamespace}, } - snap, err := s.sendNexusCompletionRequest(ctx, publicCallbackURL, completion) + snap, err := s.sendNexusCompletionRequest(ctx, env, publicCallbackURL, completion) var handlerErr *nexus.HandlerError s.ErrorAs(err, &handlerErr) s.Equal(nexus.HandlerErrorTypeNotFound, handlerErr.Type) s.Len(snap["nexus_completion_request_preprocess_errors"], 1) }) - s.Run("OperationTokenTooLong", func() { - publicCallbackURL := "http://" + s.HttpAPIAddress() + "/" + commonnexus.RouteCompletionCallback.Path(s.Namespace().String()) + s.Run("OperationTokenTooLong", func(s *NexusWorkflowTestSuite) { + env := newNexusTestEnv(s.T(), true, testcore.WithDedicatedCluster()) + publicCallbackURL := "http://" + env.HttpAPIAddress() + "/" + commonnexus.RouteCompletionCallback.Path(env.Namespace().String()) // Generate a valid callback token to get past initial validation - namespaceID := s.GetNamespaceID(s.Namespace().String()) + namespaceID := env.NamespaceID().String() validToken, err := s.generateValidCallbackToken(namespaceID, testcore.RandomizeStr("workflow"), uuid.NewString()) s.NoError(err) completion := nexusrpc.CompleteOperationOptions{ @@ -1269,19 +1273,20 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionErrors() { Header: nexus.Header{commonnexus.CallbackTokenHeader: validToken}, } - snap, err := s.sendNexusCompletionRequest(ctx, publicCallbackURL, completion) + snap, err := s.sendNexusCompletionRequest(ctx, env, publicCallbackURL, completion) var handlerErr *nexus.HandlerError s.ErrorAs(err, &handlerErr) s.Equal(nexus.HandlerErrorTypeBadRequest, handlerErr.Type) s.Empty(snap["nexus_completion_request_preprocess_errors"]) s.Len(snap["nexus_completion_requests"], 1) - s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": s.Namespace().String(), "outcome": "error_bad_request"}) + s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": env.Namespace().String(), "outcome": "error_bad_request"}) }) - s.Run("OperationTokenTooLongNoIdentifier", func() { - publicCallbackURL := "http://" + s.HttpAPIAddress() + commonnexus.PathCompletionCallbackNoIdentifier + s.Run("OperationTokenTooLongNoIdentifier", func(s *NexusWorkflowTestSuite) { + env := newNexusTestEnv(s.T(), true, testcore.WithDedicatedCluster()) + publicCallbackURL := "http://" + env.HttpAPIAddress() + commonnexus.PathCompletionCallbackNoIdentifier // Generate a valid callback token to get past initial validation - namespaceID := s.GetNamespaceID(s.Namespace().String()) + namespaceID := env.NamespaceID().String() validToken, err := s.generateValidCallbackToken(namespaceID, testcore.RandomizeStr("workflow"), uuid.NewString()) s.NoError(err) @@ -1291,23 +1296,24 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionErrors() { Header: nexus.Header{commonnexus.CallbackTokenHeader: validToken}, } - snap, err := s.sendNexusCompletionRequest(ctx, publicCallbackURL, completion) + snap, err := s.sendNexusCompletionRequest(ctx, env, publicCallbackURL, completion) var handlerErr *nexus.HandlerError s.ErrorAs(err, &handlerErr) s.Equal(nexus.HandlerErrorTypeBadRequest, handlerErr.Type) s.Empty(snap["nexus_completion_request_preprocess_errors"]) s.Len(snap["nexus_completion_requests"], 1) - s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": s.Namespace().String(), "outcome": "error_bad_request"}) + s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": env.Namespace().String(), "outcome": "error_bad_request"}) }) - s.Run("InvalidCallbackToken", func() { + s.Run("InvalidCallbackToken", func(s *NexusWorkflowTestSuite) { + env := newNexusTestEnv(s.T(), true, testcore.WithDedicatedCluster()) completion := nexusrpc.CompleteOperationOptions{ Result: testcore.MustToPayload(s.T(), "result"), } - publicCallbackURL := "http://" + s.HttpAPIAddress() + "/" + commonnexus.RouteCompletionCallback.Path(s.Namespace().String()) + publicCallbackURL := "http://" + env.HttpAPIAddress() + "/" + commonnexus.RouteCompletionCallback.Path(env.Namespace().String()) // metrics collection is not initialized before callback validation // Send request without callback token, helper does not add token if blank - _, err := s.sendNexusCompletionRequest(ctx, publicCallbackURL, completion) + _, err := s.sendNexusCompletionRequest(ctx, env, publicCallbackURL, completion) // Verify we get the correct error response var handlerErr *nexus.HandlerError s.ErrorAs(err, &handlerErr) @@ -1315,14 +1321,15 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionErrors() { s.Contains(handlerErr.Error(), "invalid callback token", "Response should indicate invalid callback token") }) - s.Run("InvalidCallbackTokenNoIdentifier", func() { + s.Run("InvalidCallbackTokenNoIdentifier", func(s *NexusWorkflowTestSuite) { + env := newNexusTestEnv(s.T(), true, testcore.WithDedicatedCluster()) completion := nexusrpc.CompleteOperationOptions{ Result: testcore.MustToPayload(s.T(), "result"), } - publicCallbackURL := "http://" + s.HttpAPIAddress() + commonnexus.PathCompletionCallbackNoIdentifier + publicCallbackURL := "http://" + env.HttpAPIAddress() + commonnexus.PathCompletionCallbackNoIdentifier // metrics collection is not initialized before callback validation // Send request without callback token, helper does not add token if blank - _, err := s.sendNexusCompletionRequest(ctx, publicCallbackURL, completion) + _, err := s.sendNexusCompletionRequest(ctx, env, publicCallbackURL, completion) // Verify we get the correct error response var handlerErr *nexus.HandlerError s.ErrorAs(err, &handlerErr) @@ -1330,13 +1337,14 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionErrors() { s.Contains(handlerErr.Error(), "invalid callback token", "Response should indicate invalid callback token") }) - s.Run("InvalidClientVersion", func() { - publicCallbackURL := "http://" + s.HttpAPIAddress() + "/" + commonnexus.RouteCompletionCallback.Path(s.Namespace().String()) - capture := s.GetTestCluster().Host().CaptureMetricsHandler().StartCapture() - defer s.GetTestCluster().Host().CaptureMetricsHandler().StopCapture(capture) + s.Run("InvalidClientVersion", func(s *NexusWorkflowTestSuite) { + env := newNexusTestEnv(s.T(), true, testcore.WithDedicatedCluster()) + publicCallbackURL := "http://" + env.HttpAPIAddress() + "/" + commonnexus.RouteCompletionCallback.Path(env.Namespace().String()) + capture := env.GetTestCluster().Host().CaptureMetricsHandler().StartCapture() + defer env.GetTestCluster().Host().CaptureMetricsHandler().StopCapture(capture) // Generate a valid callback token to get past initial validation - namespaceID := s.GetNamespaceID(s.Namespace().String()) + namespaceID := env.NamespaceID().String() validToken, err := s.generateValidCallbackToken(namespaceID, testcore.RandomizeStr("workflow"), uuid.NewString()) s.NoError(err) @@ -1356,16 +1364,17 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionErrors() { s.ErrorAs(err, &handlerErr) s.Equal(nexus.HandlerErrorTypeBadRequest, handlerErr.Type) s.Len(snap["nexus_completion_requests"], 1) - s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": s.Namespace().String(), "outcome": "unsupported_client"}) + s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": env.Namespace().String(), "outcome": "unsupported_client"}) }) - s.Run("InvalidClientVersionNoIdentifier", func() { - publicCallbackURL := "http://" + s.HttpAPIAddress() + commonnexus.PathCompletionCallbackNoIdentifier - capture := s.GetTestCluster().Host().CaptureMetricsHandler().StartCapture() - defer s.GetTestCluster().Host().CaptureMetricsHandler().StopCapture(capture) + s.Run("InvalidClientVersionNoIdentifier", func(s *NexusWorkflowTestSuite) { + env := newNexusTestEnv(s.T(), true, testcore.WithDedicatedCluster()) + publicCallbackURL := "http://" + env.HttpAPIAddress() + commonnexus.PathCompletionCallbackNoIdentifier + capture := env.GetTestCluster().Host().CaptureMetricsHandler().StartCapture() + defer env.GetTestCluster().Host().CaptureMetricsHandler().StopCapture(capture) // Generate a valid callback token to get past initial validation - namespaceID := s.GetNamespaceID(s.Namespace().String()) + namespaceID := env.NamespaceID().String() validToken, err := s.generateValidCallbackToken(namespaceID, testcore.RandomizeStr("workflow"), uuid.NewString()) s.NoError(err) @@ -1386,11 +1395,12 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionErrors() { s.ErrorAs(err, &handlerErr) s.Equal(nexus.HandlerErrorTypeBadRequest, handlerErr.Type) s.Len(snap["nexus_completion_requests"], 1) - s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": s.Namespace().String(), "outcome": "unsupported_client"}) + s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": env.Namespace().String(), "outcome": "unsupported_client"}) }) } func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionAuthErrors() { + env := newNexusTestEnv(s.T(), true, testcore.WithDedicatedCluster()) ctx := testcore.NewContext() onAuthorize := func(ctx context.Context, c *authorization.Claims, ct *authorization.CallTarget) (authorization.Result, error) { @@ -1399,11 +1409,11 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionAuthErrors() { } return authorization.Result{Decision: authorization.DecisionAllow}, nil } - s.GetTestCluster().Host().SetOnAuthorize(onAuthorize) - defer s.GetTestCluster().Host().SetOnAuthorize(nil) + env.GetTestCluster().Host().SetOnAuthorize(onAuthorize) + defer env.GetTestCluster().Host().SetOnAuthorize(nil) // Generate a valid callback token for testing - namespaceID := s.GetNamespaceID(s.Namespace().String()) + namespaceID := env.NamespaceID().String() callbackToken, err := s.generateValidCallbackToken(namespaceID, testcore.RandomizeStr("workflow"), uuid.NewString()) s.NoError(err) @@ -1412,16 +1422,17 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionAuthErrors() { Header: nexus.Header{commonnexus.CallbackTokenHeader: callbackToken}, } - publicCallbackURL := "http://" + s.HttpAPIAddress() + "/" + commonnexus.RouteCompletionCallback.Path(s.Namespace().String()) - snap, err := s.sendNexusCompletionRequest(ctx, publicCallbackURL, completion) + publicCallbackURL := "http://" + env.HttpAPIAddress() + "/" + commonnexus.RouteCompletionCallback.Path(env.Namespace().String()) + snap, err := s.sendNexusCompletionRequest(ctx, env, publicCallbackURL, completion) var handlerErr *nexus.HandlerError s.ErrorAs(err, &handlerErr) s.Equal(nexus.HandlerErrorTypeUnauthorized, handlerErr.Type) s.Len(snap["nexus_completion_requests"], 1) - s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": s.Namespace().String(), "outcome": "unauthorized"}) + s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": env.Namespace().String(), "outcome": "unauthorized"}) } func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionAuthErrorsNoIdentifier() { + env := newNexusTestEnv(s.T(), true, testcore.WithDedicatedCluster()) ctx := testcore.NewContext() onAuthorize := func(ctx context.Context, c *authorization.Claims, ct *authorization.CallTarget) (authorization.Result, error) { @@ -1430,11 +1441,11 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionAuthErrorsNoId } return authorization.Result{Decision: authorization.DecisionAllow}, nil } - s.GetTestCluster().Host().SetOnAuthorize(onAuthorize) - defer s.GetTestCluster().Host().SetOnAuthorize(nil) + env.GetTestCluster().Host().SetOnAuthorize(onAuthorize) + defer env.GetTestCluster().Host().SetOnAuthorize(nil) // Generate a valid callback token for testing - namespaceID := s.GetNamespaceID(s.Namespace().String()) + namespaceID := env.NamespaceID().String() callbackToken, err := s.generateValidCallbackToken(namespaceID, testcore.RandomizeStr("workflow"), uuid.NewString()) s.NoError(err) @@ -1442,18 +1453,19 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionAuthErrorsNoId Result: testcore.MustToPayload(s.T(), "result"), Header: nexus.Header{commonnexus.CallbackTokenHeader: callbackToken}, } - publicCallbackURL := "http://" + s.HttpAPIAddress() + commonnexus.PathCompletionCallbackNoIdentifier - snap, err := s.sendNexusCompletionRequest(ctx, publicCallbackURL, completion) + publicCallbackURL := "http://" + env.HttpAPIAddress() + commonnexus.PathCompletionCallbackNoIdentifier + snap, err := s.sendNexusCompletionRequest(ctx, env, publicCallbackURL, completion) var handlerErr *nexus.HandlerError s.ErrorAs(err, &handlerErr) s.Equal(nexus.HandlerErrorTypeUnauthorized, handlerErr.Type) s.Len(snap["nexus_completion_requests"], 1) - s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": s.Namespace().String(), "outcome": "unauthorized"}) + s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": env.Namespace().String(), "outcome": "unauthorized"}) } func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionInternalAuth() { + env := newNexusTestEnv(s.T(), true, testcore.WithDedicatedCluster()) // Set URL template with invalid host - s.OverrideDynamicConfig( + env.OverrideDynamicConfig( nexusoperations.CallbackURLTemplate, "http://INTERNAL/namespaces/{{.NamespaceName}}/nexus/callback") @@ -1461,13 +1473,13 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionInternalAuth() taskQueue := testcore.RandomizeStr(s.T().Name()) endpointName := testcore.RandomizedNexusEndpoint(s.T().Name()) - _, err := s.OperatorClient().CreateNexusEndpoint(ctx, &operatorservice.CreateNexusEndpointRequest{ + _, err := env.OperatorClient().CreateNexusEndpoint(ctx, &operatorservice.CreateNexusEndpointRequest{ Spec: &nexuspb.EndpointSpec{ Name: endpointName, Target: &nexuspb.EndpointTarget{ Variant: &nexuspb.EndpointTarget_Worker_{ Worker: &nexuspb.EndpointTarget_Worker{ - Namespace: s.Namespace().String(), + Namespace: env.Namespace().String(), TaskQueue: taskQueue, }, }, @@ -1476,7 +1488,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionInternalAuth() }) s.NoError(err) - run, err := s.SdkClient().ExecuteWorkflow(ctx, client.StartWorkflowOptions{ + run, err := env.SdkClient().ExecuteWorkflow(ctx, client.StartWorkflowOptions{ TaskQueue: taskQueue, }, "workflow") s.NoError(err) @@ -1485,7 +1497,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionInternalAuth() completionWFTaskQueue := testcore.RandomizeStr(s.T().Name()) completionWFStartReq := &workflowservice.StartWorkflowExecutionRequest{ RequestId: uuid.NewString(), - Namespace: s.Namespace().String(), + Namespace: env.Namespace().String(), WorkflowId: testcore.RandomizeStr(s.T().Name()), WorkflowType: &commonpb.WorkflowType{Name: completionWFType}, TaskQueue: &taskqueuepb.TaskQueue{Name: completionWFTaskQueue, Kind: enumspb.TASK_QUEUE_KIND_NORMAL}, @@ -1494,9 +1506,9 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionInternalAuth() Identity: "test", } - pollerErrCh := s.nexusTaskPoller(ctx, taskQueue, func(res *workflowservice.PollNexusTaskQueueResponse) (*nexusTaskResponse, error) { + pollerErrCh := env.nexusTaskPoller(ctx, s.T(), taskQueue, func(t *testing.T, res *workflowservice.PollNexusTaskQueueResponse) (*nexusTaskResponse, error) { start := res.Request.Variant.(*nexuspb.Request_StartOperation).StartOperation - s.Equal(op.Name(), start.Operation) + require.Equal(t, op.Name(), start.Operation) completionWFStartReq.CompletionCallbacks = []*commonpb.Callback{ { @@ -1509,7 +1521,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionInternalAuth() }, } - _, err := s.FrontendClient().StartWorkflowExecution(ctx, completionWFStartReq) + _, err := env.FrontendClient().StartWorkflowExecution(ctx, completionWFStartReq) if err != nil { return nil, err } @@ -1517,8 +1529,8 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionInternalAuth() return &nexusTaskResponse{StartResult: &nexus.HandlerStartOperationResultAsync{OperationToken: "test-token"}}, nil }) - pollResp, err := s.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ - Namespace: s.Namespace().String(), + pollResp, err := env.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ + Namespace: env.Namespace().String(), TaskQueue: &taskqueuepb.TaskQueue{ Name: taskQueue, Kind: enumspb.TASK_QUEUE_KIND_NORMAL, @@ -1526,7 +1538,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionInternalAuth() Identity: "test", }) s.NoError(err) - _, err = s.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ + _, err = env.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ Identity: "test", TaskToken: pollResp.TaskToken, Commands: []*commandpb.Command{ @@ -1546,8 +1558,8 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionInternalAuth() s.NoError(err) // Poll and verify that the "started" event was recorded. - pollResp, err = s.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ - Namespace: s.Namespace().String(), + pollResp, err = env.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ + Namespace: env.Namespace().String(), TaskQueue: &taskqueuepb.TaskQueue{ Name: taskQueue, Kind: enumspb.TASK_QUEUE_KIND_NORMAL, @@ -1555,7 +1567,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionInternalAuth() Identity: "test", }) s.NoError(err) - _, err = s.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ + _, err = env.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ Identity: "test", TaskToken: pollResp.TaskToken, }) @@ -1566,8 +1578,8 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionInternalAuth() s.Positive(startedEventIdx) // Complete workflow containing callback - pollResp, err = s.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ - Namespace: s.Namespace().String(), + pollResp, err = env.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ + Namespace: env.Namespace().String(), TaskQueue: &taskqueuepb.TaskQueue{ Name: completionWFTaskQueue, Kind: enumspb.TASK_QUEUE_KIND_NORMAL, @@ -1575,7 +1587,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionInternalAuth() Identity: "test", }) s.NoError(err) - _, err = s.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ + _, err = env.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ Identity: "test", TaskToken: pollResp.TaskToken, Commands: []*commandpb.Command{ @@ -1596,8 +1608,8 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionInternalAuth() s.NoError(err) // Poll again and verify the completion is recorded and triggers workflow progress. - pollResp, err = s.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ - Namespace: s.Namespace().String(), + pollResp, err = env.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ + Namespace: env.Namespace().String(), TaskQueue: &taskqueuepb.TaskQueue{ Name: taskQueue, Kind: enumspb.TASK_QUEUE_KIND_NORMAL, @@ -1610,7 +1622,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionInternalAuth() }) s.Positive(completedEventIdx) - _, err = s.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ + _, err = env.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ Identity: "test", TaskToken: pollResp.TaskToken, Commands: []*commandpb.Command{ @@ -1637,6 +1649,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionInternalAuth() } func (s *NexusWorkflowTestSuite) TestNexusOperationCancelBeforeStarted_CancelationEventuallyDelivered() { + env := newNexusTestEnv(s.T(), true, testcore.WithDedicatedCluster()) ctx := testcore.NewContext() taskQueue := testcore.RandomizeStr(s.T().Name()) endpointName := testcore.RandomizedNexusEndpoint(s.T().Name()) @@ -1661,7 +1674,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationCancelBeforeStarted_Cancelati listenAddr := nexustest.AllocListenAddress() nexustest.NewNexusServer(s.T(), listenAddr, h) - _, err := s.OperatorClient().CreateNexusEndpoint(ctx, &operatorservice.CreateNexusEndpointRequest{ + _, err := env.OperatorClient().CreateNexusEndpoint(ctx, &operatorservice.CreateNexusEndpointRequest{ Spec: &nexuspb.EndpointSpec{ Name: endpointName, Target: &nexuspb.EndpointTarget{ @@ -1675,13 +1688,13 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationCancelBeforeStarted_Cancelati }) s.NoError(err) - run, err := s.SdkClient().ExecuteWorkflow(ctx, client.StartWorkflowOptions{ + run, err := env.SdkClient().ExecuteWorkflow(ctx, client.StartWorkflowOptions{ TaskQueue: taskQueue, }, "workflow") s.NoError(err) - pollResp, err := s.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ - Namespace: s.Namespace().String(), + pollResp, err := env.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ + Namespace: env.Namespace().String(), TaskQueue: &taskqueuepb.TaskQueue{ Name: taskQueue, Kind: enumspb.TASK_QUEUE_KIND_NORMAL, @@ -1689,7 +1702,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationCancelBeforeStarted_Cancelati Identity: "test", }) s.NoError(err) - _, err = s.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ + _, err = env.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ Identity: "test", TaskToken: pollResp.TaskToken, Commands: []*commandpb.Command{ @@ -1719,8 +1732,8 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationCancelBeforeStarted_Cancelati s.NoError(err) // Poll and cancel the operation. - pollResp, err = s.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ - Namespace: s.Namespace().String(), + pollResp, err = env.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ + Namespace: env.Namespace().String(), TaskQueue: &taskqueuepb.TaskQueue{ Name: taskQueue, Kind: enumspb.TASK_QUEUE_KIND_NORMAL, @@ -1736,7 +1749,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationCancelBeforeStarted_Cancelati s.Positive(scheduledEventIdx) scheduledEventID := pollResp.History.Events[scheduledEventIdx].EventId - _, err = s.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ + _, err = env.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ Identity: "test", TaskToken: pollResp.TaskToken, Commands: []*commandpb.Command{ @@ -1753,14 +1766,14 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationCancelBeforeStarted_Cancelati s.NoError(err) canStartCh <- struct{}{} - s.WaitForChannel(ctx, cancelSentCh) + env.WaitForChannel(ctx, cancelSentCh) // Terminate the workflow for good measure. - err = s.SdkClient().TerminateWorkflow(ctx, run.GetID(), run.GetRunID(), "test") + err = env.SdkClient().TerminateWorkflow(ctx, run.GetID(), run.GetRunID(), "test") s.NoError(err) // Assert that we did not send a cancel request until after the operation was started. - hist := s.GetHistory(s.Namespace().String(), &commonpb.WorkflowExecution{ + hist := env.GetHistory(env.Namespace().String(), &commonpb.WorkflowExecution{ WorkflowId: run.GetID(), RunId: run.GetRunID(), }) @@ -1770,6 +1783,7 @@ NexusOperationStarted`, hist) } func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionAfterReset() { + env := newNexusTestEnv(s.T(), true, testcore.WithDedicatedCluster()) ctx := testcore.NewContext() taskQueue := testcore.RandomizeStr(s.T().Name()) endpointName := testcore.RandomizedNexusEndpoint(s.T().Name()) @@ -1786,7 +1800,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionAfterReset() { listenAddr := nexustest.AllocListenAddress() nexustest.NewNexusServer(s.T(), listenAddr, h) - _, err := s.OperatorClient().CreateNexusEndpoint(ctx, &operatorservice.CreateNexusEndpointRequest{ + _, err := env.OperatorClient().CreateNexusEndpoint(ctx, &operatorservice.CreateNexusEndpointRequest{ Spec: &nexuspb.EndpointSpec{ Name: endpointName, Target: &nexuspb.EndpointTarget{ @@ -1800,13 +1814,13 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionAfterReset() { }) s.NoError(err) - run, err := s.SdkClient().ExecuteWorkflow(ctx, client.StartWorkflowOptions{ + run, err := env.SdkClient().ExecuteWorkflow(ctx, client.StartWorkflowOptions{ TaskQueue: taskQueue, }, "workflow") s.NoError(err) - pollResp, err := s.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ - Namespace: s.Namespace().String(), + pollResp, err := env.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ + Namespace: env.Namespace().String(), TaskQueue: &taskqueuepb.TaskQueue{ Name: taskQueue, Kind: enumspb.TASK_QUEUE_KIND_NORMAL, @@ -1814,7 +1828,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionAfterReset() { Identity: "test", }) s.NoError(err) - _, err = s.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ + _, err = env.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ Identity: "test", TaskToken: pollResp.TaskToken, Commands: []*commandpb.Command{ @@ -1834,8 +1848,8 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionAfterReset() { s.NoError(err) // Poll and verify that the "started" event was recorded. - pollResp, err = s.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ - Namespace: s.Namespace().String(), + pollResp, err = env.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ + Namespace: env.Namespace().String(), TaskQueue: &taskqueuepb.TaskQueue{ Name: taskQueue, Kind: enumspb.TASK_QUEUE_KIND_NORMAL, @@ -1843,7 +1857,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionAfterReset() { Identity: "test", }) s.NoError(err) - _, err = s.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ + _, err = env.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ Identity: "test", TaskToken: pollResp.TaskToken, }) @@ -1859,8 +1873,8 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionAfterReset() { wftCompletedEventID := int64(len(pollResp.History.Events)) // Reset the workflow and check that the started event has been reapplied. - resetResp, err := s.FrontendClient().ResetWorkflowExecution(ctx, &workflowservice.ResetWorkflowExecutionRequest{ - Namespace: s.Namespace().String(), + resetResp, err := env.FrontendClient().ResetWorkflowExecution(ctx, &workflowservice.ResetWorkflowExecutionRequest{ + Namespace: env.Namespace().String(), WorkflowExecution: pollResp.WorkflowExecution, Reason: "test", RequestId: uuid.NewString(), @@ -1868,7 +1882,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionAfterReset() { }) s.NoError(err) - hist := s.SdkClient().GetWorkflowHistory(ctx, run.GetID(), resetResp.RunId, false, enumspb.HISTORY_EVENT_FILTER_TYPE_ALL_EVENT) + hist := env.SdkClient().GetWorkflowHistory(ctx, run.GetID(), resetResp.RunId, false, enumspb.HISTORY_EVENT_FILTER_TYPE_ALL_EVENT) seenStartedEvent := false for hist.HasNext() { @@ -1883,12 +1897,12 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionAfterReset() { Result: testcore.MustToPayload(s.T(), "result"), Header: nexus.Header{commonnexus.CallbackTokenHeader: callbackToken}, } - _, err = s.sendNexusCompletionRequest(ctx, publicCallbackUrl, completion) + _, err = s.sendNexusCompletionRequest(ctx, env, publicCallbackUrl, completion) s.NoError(err) // Poll again and verify the completion is recorded and triggers workflow progress. - pollResp, err = s.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ - Namespace: s.Namespace().String(), + pollResp, err = env.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ + Namespace: env.Namespace().String(), TaskQueue: &taskqueuepb.TaskQueue{ Name: taskQueue, Kind: enumspb.TASK_QUEUE_KIND_NORMAL, @@ -1901,7 +1915,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionAfterReset() { }) s.Positive(completedEventIdx) - _, err = s.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ + _, err = env.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ Identity: "test", TaskToken: pollResp.TaskToken, Commands: []*commandpb.Command{ @@ -1921,12 +1935,13 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionAfterReset() { }) s.NoError(err) var result string - run = s.SdkClient().GetWorkflow(ctx, run.GetID(), resetResp.RunId) + run = env.SdkClient().GetWorkflow(ctx, run.GetID(), resetResp.RunId) s.NoError(run.Get(ctx, &result)) s.Equal("result", result) } func (s *NexusWorkflowTestSuite) TestNexusAsyncOperationWithNilIO() { + env := newNexusTestEnv(s.T(), true, testcore.WithDedicatedCluster()) ctx, cancel := context.WithTimeout(context.Background(), time.Second*20) defer cancel() callerTaskQueue := testcore.RandomizeStr("caller_" + s.T().Name()) @@ -1934,13 +1949,13 @@ func (s *NexusWorkflowTestSuite) TestNexusAsyncOperationWithNilIO() { endpointName := testcore.RandomizedNexusEndpoint(s.T().Name()) handlerWorkflowID := testcore.RandomizeStr(s.T().Name()) - _, err := s.SdkClient().OperatorService().CreateNexusEndpoint(ctx, &operatorservice.CreateNexusEndpointRequest{ + _, err := env.SdkClient().OperatorService().CreateNexusEndpoint(ctx, &operatorservice.CreateNexusEndpointRequest{ Spec: &nexuspb.EndpointSpec{ Name: endpointName, Target: &nexuspb.EndpointTarget{ Variant: &nexuspb.EndpointTarget_Worker_{ Worker: &nexuspb.EndpointTarget_Worker{ - Namespace: s.Namespace().String(), + Namespace: env.Namespace().String(), TaskQueue: callerTaskQueue, }, }, @@ -1950,7 +1965,7 @@ func (s *NexusWorkflowTestSuite) TestNexusAsyncOperationWithNilIO() { s.NoError(err) w := worker.New( - s.SdkClient(), + env.SdkClient(), callerTaskQueue, worker.Options{}, ) @@ -1983,21 +1998,21 @@ func (s *NexusWorkflowTestSuite) TestNexusAsyncOperationWithNilIO() { w.Start() defer w.Stop() - run, err := s.SdkClient().ExecuteWorkflow(ctx, client.StartWorkflowOptions{ + run, err := env.SdkClient().ExecuteWorkflow(ctx, client.StartWorkflowOptions{ TaskQueue: callerTaskQueue, }, callerWF, nil) s.NoError(err) - pollRes, err := s.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ - Namespace: s.Namespace().String(), + pollRes, err := env.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ + Namespace: env.Namespace().String(), TaskQueue: &taskqueuepb.TaskQueue{ Name: handlerWorkflowTaskQueue, }, Identity: "test", }) s.NoError(err) - _, err = s.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ - Namespace: s.Namespace().String(), + _, err = env.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ + Namespace: env.Namespace().String(), TaskToken: pollRes.TaskToken, Identity: "test", Commands: []*commandpb.Command{ @@ -2014,7 +2029,7 @@ func (s *NexusWorkflowTestSuite) TestNexusAsyncOperationWithNilIO() { s.NoError(err) s.NoError(run.Get(ctx, nil)) - history := s.SdkClient().GetWorkflowHistory(ctx, run.GetID(), "", false, enumspb.HISTORY_EVENT_FILTER_TYPE_ALL_EVENT) + history := env.SdkClient().GetWorkflowHistory(ctx, run.GetID(), "", false, enumspb.HISTORY_EVENT_FILTER_TYPE_ALL_EVENT) for history.HasNext() { ev, err := history.Next() s.NoError(err) @@ -2026,354 +2041,358 @@ func (s *NexusWorkflowTestSuite) TestNexusAsyncOperationWithNilIO() { } func (s *NexusWorkflowTestSuite) TestNexusSyncOperationErrorRehydration() { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*20) - defer cancel() - taskQueue := testcore.RandomizeStr("caller_" + s.T().Name()) - endpointName := testcore.RandomizedNexusEndpoint(s.T().Name()) - converter := temporal.NewDefaultFailureConverter(temporal.DefaultFailureConverterOptions{}) - - _, err := s.SdkClient().OperatorService().CreateNexusEndpoint(ctx, &operatorservice.CreateNexusEndpointRequest{ - Spec: &nexuspb.EndpointSpec{ - Name: endpointName, - Target: &nexuspb.EndpointTarget{ - Variant: &nexuspb.EndpointTarget_Worker_{ - Worker: &nexuspb.EndpointTarget_Worker{ - Namespace: s.Namespace().String(), - TaskQueue: taskQueue, - }, - }, - }, - }, - }) - s.NoError(err) - - w := worker.New( - s.SdkClient(), - taskQueue, - worker.Options{}, - ) - - svc := nexus.NewService("test") - op := nexus.NewSyncOperation("op", func(ctx context.Context, outcome string, soo nexus.StartOperationOptions) (nexus.NoValue, error) { - switch outcome { - case "fail-handler-internal": - return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "intentional internal error") - case "fail-handler-app-error": - return nil, temporal.NewApplicationError("app error", "TestError", "details") - case "fail-handler-bad-request": - return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "bad request") - case "fail-operation": - return nil, nexus.NewOperationFailedErrorf("some error") - case "fail-operation-app-error": - return nil, temporal.NewNonRetryableApplicationError("app error", "TestError", nil, "details") - } - return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "unexpected outcome: %s", outcome) - }) - s.NoError(svc.Register(op)) - - callerWF := func(ctx workflow.Context, outcome string) (nexus.NoValue, error) { - c := workflow.NewNexusClient(endpointName, svc.Name) - fut := c.ExecuteOperation(ctx, op, outcome, workflow.NexusOperationOptions{}) - return nil, fut.Get(ctx, nil) - } - - w.RegisterNexusService(svc) - w.RegisterWorkflow(callerWF) - s.NoError(w.Start()) - s.T().Cleanup(w.Stop) - - cases := []struct { + type testcase struct { outcome string metricsOutcome string - checkPendingError func(t *testing.T, pendingErr error) - checkWorkflowError func(t *testing.T, wfErr error) - }{ + checkPendingError func(s *NexusWorkflowTestSuite, pendingErr error) + checkWorkflowError func(s *NexusWorkflowTestSuite, wfErr error) + } + cases := []testcase{ { outcome: "fail-handler-internal", metricsOutcome: "handler-error:INTERNAL", - checkPendingError: func(t *testing.T, pendingErr error) { + checkPendingError: func(s *NexusWorkflowTestSuite, pendingErr error) { var handlerErr *nexus.HandlerError - require.ErrorAs(t, pendingErr, &handlerErr) - require.Equal(t, nexus.HandlerErrorTypeInternal, handlerErr.Type) - require.Equal(t, "intentional internal error", handlerErr.Message) + s.ErrorAs(pendingErr, &handlerErr) + s.Equal(nexus.HandlerErrorTypeInternal, handlerErr.Type) + s.Equal("intentional internal error", handlerErr.Message) }, }, { outcome: "fail-handler-app-error", metricsOutcome: "handler-error:INTERNAL", - checkPendingError: func(t *testing.T, pendingErr error) { + checkPendingError: func(s *NexusWorkflowTestSuite, pendingErr error) { var handlerErr *nexus.HandlerError - require.ErrorAs(t, pendingErr, &handlerErr) - require.Equal(t, nexus.HandlerErrorTypeInternal, handlerErr.Type) + s.ErrorAs(pendingErr, &handlerErr) + s.Equal(nexus.HandlerErrorTypeInternal, handlerErr.Type) var appErr *temporal.ApplicationError - require.ErrorAs(t, handlerErr.Cause, &appErr) - require.Equal(t, "app error", appErr.Message()) - require.Equal(t, "TestError", appErr.Type()) + s.ErrorAs(handlerErr.Cause, &appErr) + s.Equal("app error", appErr.Message()) + s.Equal("TestError", appErr.Type()) var details string - require.NoError(t, appErr.Details(&details)) - require.Equal(t, "details", details) + s.NoError(appErr.Details(&details)) + s.Equal("details", details) }, }, { outcome: "fail-handler-bad-request", metricsOutcome: "handler-error:BAD_REQUEST", - checkWorkflowError: func(t *testing.T, wfErr error) { + checkWorkflowError: func(s *NexusWorkflowTestSuite, wfErr error) { var opErr *temporal.NexusOperationError - require.ErrorAs(t, wfErr, &opErr) + s.ErrorAs(wfErr, &opErr) var handlerErr *nexus.HandlerError - require.ErrorAs(t, opErr, &handlerErr) - require.Equal(t, nexus.HandlerErrorTypeBadRequest, handlerErr.Type) - require.Equal(t, "bad request", handlerErr.Message) + s.ErrorAs(opErr, &handlerErr) + s.Equal(nexus.HandlerErrorTypeBadRequest, handlerErr.Type) + s.Equal("bad request", handlerErr.Message) }, }, { outcome: "fail-operation", metricsOutcome: "operation-unsuccessful:failed", - checkWorkflowError: func(t *testing.T, wfErr error) { + checkWorkflowError: func(s *NexusWorkflowTestSuite, wfErr error) { var opErr *temporal.NexusOperationError - require.ErrorAs(t, wfErr, &opErr) - require.Equal(t, "nexus operation completed unsuccessfully", opErr.Message) + s.ErrorAs(wfErr, &opErr) + s.Equal("nexus operation completed unsuccessfully", opErr.Message) var appErr *temporal.ApplicationError - require.ErrorAs(t, opErr.Cause, &appErr) - require.Equal(t, "some error", appErr.Message()) + s.ErrorAs(opErr.Cause, &appErr) + s.Equal("some error", appErr.Message()) }, }, { outcome: "fail-operation-app-error", metricsOutcome: "handler-error:INTERNAL", - checkWorkflowError: func(t *testing.T, wfErr error) { + checkWorkflowError: func(s *NexusWorkflowTestSuite, wfErr error) { var opErr *temporal.NexusOperationError - require.ErrorAs(t, wfErr, &opErr) + s.ErrorAs(wfErr, &opErr) var appErr *temporal.ApplicationError - require.ErrorAs(t, opErr, &appErr) - require.Equal(t, "app error", appErr.Message()) - require.Equal(t, "TestError", appErr.Type()) + s.ErrorAs(opErr, &appErr) + s.Equal("app error", appErr.Message()) + s.Equal("TestError", appErr.Type()) var details string - require.NoError(t, appErr.Details(&details)) - require.Equal(t, "details", details) + s.NoError(appErr.Details(&details)) + s.Equal("details", details) }, }, } - for _, tc := range cases { - s.T().Run(tc.outcome, func(t *testing.T) { - capture := s.GetTestCluster().Host().CaptureMetricsHandler().StartCapture() - run, err := s.SdkClient().ExecuteWorkflow(ctx, client.StartWorkflowOptions{ - TaskQueue: taskQueue, - }, callerWF, tc.outcome) - s.NoError(err) - - if tc.checkPendingError != nil { - var f *failurepb.Failure - require.EventuallyWithT(t, func(t *assert.CollectT) { - desc, err := s.SdkClient().DescribeWorkflowExecution(ctx, run.GetID(), run.GetRunID()) - require.NoError(t, err) - require.Len(t, desc.PendingNexusOperations, 1) - f = desc.PendingNexusOperations[0].LastAttemptFailure - require.NotNil(t, f) - - }, 10*time.Second, 100*time.Millisecond) - s.GetTestCluster().Host().CaptureMetricsHandler().StopCapture(capture) - tc.checkPendingError(t, converter.FailureToError(f)) - s.NoError(s.SdkClient().TerminateWorkflow(ctx, run.GetID(), run.GetRunID(), "test cleanup")) - } else { - wfErr := run.Get(ctx, nil) - s.GetTestCluster().Host().CaptureMetricsHandler().StopCapture(capture) - tc.checkWorkflowError(t, wfErr) - } - - snap := capture.Snapshot() - require.Len(t, snap["nexus_outbound_requests"], 1) - require.Subset( - t, - snap["nexus_outbound_requests"][0].Tags, - map[string]string{ - "namespace": s.Namespace().String(), - "method": "StartOperation", - "failure_source": "worker", - "outcome": tc.metricsOutcome, - }, - ) - }) - - } -} - -func (s *NexusWorkflowTestSuite) TestNexusAsyncOperationErrorRehydration() { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*20) - defer cancel() - testCtx := ctx - taskQueue := testcore.RandomizeStr("caller_" + s.T().Name()) - endpointName := testcore.RandomizedNexusEndpoint(s.T().Name()) - handlerWorkflowID := testcore.RandomizeStr(s.T().Name()) + testFn := func(s *NexusWorkflowTestSuite, tc testcase) { + env := newNexusTestEnv(s.T(), true, testcore.WithDedicatedCluster()) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*20) + defer cancel() + taskQueue := testcore.RandomizeStr("caller_" + s.T().Name()) + endpointName := testcore.RandomizedNexusEndpoint(s.T().Name()) + converter := temporal.NewDefaultFailureConverter(temporal.DefaultFailureConverterOptions{}) - _, err := s.SdkClient().OperatorService().CreateNexusEndpoint(ctx, &operatorservice.CreateNexusEndpointRequest{ - Spec: &nexuspb.EndpointSpec{ - Name: endpointName, - Target: &nexuspb.EndpointTarget{ - Variant: &nexuspb.EndpointTarget_Worker_{ - Worker: &nexuspb.EndpointTarget_Worker{ - Namespace: s.Namespace().String(), - TaskQueue: taskQueue, + _, err := env.SdkClient().OperatorService().CreateNexusEndpoint(ctx, &operatorservice.CreateNexusEndpointRequest{ + Spec: &nexuspb.EndpointSpec{ + Name: endpointName, + Target: &nexuspb.EndpointTarget{ + Variant: &nexuspb.EndpointTarget_Worker_{ + Worker: &nexuspb.EndpointTarget_Worker{ + Namespace: env.Namespace().String(), + TaskQueue: taskQueue, + }, }, }, }, - }, - }) - s.NoError(err) - - w := worker.New( - s.SdkClient(), - taskQueue, - worker.Options{}, - ) + }) + s.NoError(err) - svc := nexus.NewService("test") + w := worker.New(env.SdkClient(), taskQueue, worker.Options{}) + svc := nexus.NewService("test") + op := nexus.NewSyncOperation("op", func(ctx context.Context, outcome string, soo nexus.StartOperationOptions) (nexus.NoValue, error) { + switch outcome { + case "fail-handler-internal": + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "intentional internal error") + case "fail-handler-app-error": + return nil, temporal.NewApplicationError("app error", "TestError", "details") + case "fail-handler-bad-request": + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "bad request") + case "fail-operation": + return nil, nexus.NewOperationFailedErrorf("some error") + case "fail-operation-app-error": + return nil, temporal.NewNonRetryableApplicationError("app error", "TestError", nil, "details") + default: + } + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "unexpected outcome: %s", outcome) + }) + s.NoError(svc.Register(op)) - handlerWF := func(ctx workflow.Context, outcome string) (nexus.NoValue, error) { - switch outcome { - case "wait", "timeout": - // Wait for the workflow to be canceled. - return nil, workflow.Await(ctx, func() bool { return false }) - case "fail": - return nil, temporal.NewApplicationError("app error", "TestError", "details") + callerWF := func(ctx workflow.Context, outcome string) (nexus.NoValue, error) { + c := workflow.NewNexusClient(endpointName, svc.Name) + fut := c.ExecuteOperation(ctx, op, outcome, workflow.NexusOperationOptions{}) + return nil, fut.Get(ctx, nil) } - return nil, fmt.Errorf("unexpected outcome: %s", outcome) - } - op := temporalnexus.NewWorkflowRunOperation("op", handlerWF, func(ctx context.Context, outcome string, soo nexus.StartOperationOptions) (client.StartWorkflowOptions, error) { - var workflowExecutionTimeout time.Duration - if outcome == "timeout" { - workflowExecutionTimeout = time.Second - } - return client.StartWorkflowOptions{ID: handlerWorkflowID, WorkflowExecutionTimeout: workflowExecutionTimeout}, nil - }) - s.NoError(svc.Register(op)) + w.RegisterNexusService(svc) + w.RegisterWorkflow(callerWF) + s.NoError(w.Start()) + defer w.Stop() - callerWF := func(ctx workflow.Context, outcome, action string) (nexus.NoValue, error) { - opCtx, cancel := workflow.WithCancel(ctx) - defer cancel() - c := workflow.NewNexusClient(endpointName, svc.Name) - fut := c.ExecuteOperation(opCtx, op, outcome, workflow.NexusOperationOptions{}) - var exec workflow.NexusOperationExecution - if err := fut.GetNexusOperationExecution().Get(ctx, &exec); err != nil { - return nil, err - } - switch action { - case "terminate": - // Lazy man's version of a local activity, don't try this at home. - workflow.SideEffect(ctx, func(ctx workflow.Context) any { - err := s.SdkClient().TerminateWorkflow(testCtx, handlerWorkflowID, "", "") - if err != nil { - panic(err) + capture := env.GetTestCluster().Host().CaptureMetricsHandler().StartCapture() + run, err := env.SdkClient().ExecuteWorkflow(ctx, client.StartWorkflowOptions{ + TaskQueue: taskQueue, + }, callerWF, tc.outcome) + s.NoError(err) + + if tc.checkPendingError != nil { + var f *failurepb.Failure + s.EventuallyWithT(func(t *assert.CollectT) { + desc, err := env.SdkClient().DescribeWorkflowExecution(ctx, run.GetID(), run.GetRunID()) + assert.NoError(t, err) + assert.Len(t, desc.PendingNexusOperations, 1) + if len(desc.PendingNexusOperations) > 0 { + f = desc.PendingNexusOperations[0].LastAttemptFailure + assert.NotNil(t, f) } - return nil - }) - case "cancel": - cancel() - err := fut.Get(ctx, nil) - // The Go SDK unwraps CanceledErrors when an error is returned from the workflow, assert in-workflow. - var opErr *temporal.NexusOperationError - if !errors.As(err, &opErr) { - return nil, fmt.Errorf("expected NexusOperationError, got %w", err) - } - var canceledErr *temporal.CanceledError - if !errors.As(opErr, &canceledErr) { - return nil, fmt.Errorf("expected CanceledError, got %w", err) - } + }, 10*time.Second, 100*time.Millisecond) + env.GetTestCluster().Host().CaptureMetricsHandler().StopCapture(capture) + tc.checkPendingError(s, converter.FailureToError(f)) + s.NoError(env.SdkClient().TerminateWorkflow(ctx, run.GetID(), run.GetRunID(), "test cleanup")) + } else { + wfErr := run.Get(ctx, nil) + env.GetTestCluster().Host().CaptureMetricsHandler().StopCapture(capture) + tc.checkWorkflowError(s, wfErr) } - return nil, fut.Get(ctx, nil) + + snap := capture.Snapshot() + s.Len(snap["nexus_outbound_requests"], 1) + s.Subset( + snap["nexus_outbound_requests"][0].Tags, + map[string]string{ + "namespace": env.Namespace().String(), + "method": "StartOperation", + "failure_source": "worker", + "outcome": tc.metricsOutcome, + }, + ) } - w.RegisterNexusService(svc) - w.RegisterWorkflow(callerWF) - w.RegisterWorkflow(handlerWF) - s.NoError(w.Start()) - s.T().Cleanup(w.Stop) + for _, tc := range cases { + s.Run(tc.outcome, func(s *NexusWorkflowTestSuite) { + testFn(s, tc) + }) + } +} - cases := []struct { +func (s *NexusWorkflowTestSuite) TestNexusAsyncOperationErrorRehydration() { + type testcase struct { outcome, action string - checkWorkflowError func(t *testing.T, wfErr error) - }{ + checkWorkflowError func(s *NexusWorkflowTestSuite, wfErr error) + } + cases := []testcase{ { outcome: "fail", - checkWorkflowError: func(t *testing.T, wfErr error) { + checkWorkflowError: func(s *NexusWorkflowTestSuite, wfErr error) { var opErr *temporal.NexusOperationError - require.ErrorAs(t, wfErr, &opErr) + s.ErrorAs(wfErr, &opErr) var appErr *temporal.ApplicationError - require.ErrorAs(t, opErr, &appErr) - require.Equal(t, "app error", appErr.Message()) - require.Equal(t, "TestError", appErr.Type()) + s.ErrorAs(opErr, &appErr) + s.Equal("app error", appErr.Message()) + s.Equal("TestError", appErr.Type()) var details string - require.NoError(t, appErr.Details(&details)) - require.Equal(t, "details", details) + s.NoError(appErr.Details(&details)) + s.Equal("details", details) }, }, { outcome: "wait", action: "terminate", - checkWorkflowError: func(t *testing.T, wfErr error) { + checkWorkflowError: func(s *NexusWorkflowTestSuite, wfErr error) { var opErr *temporal.NexusOperationError - require.ErrorAs(t, wfErr, &opErr) + s.ErrorAs(wfErr, &opErr) var termErr *temporal.TerminatedError - require.ErrorAs(t, opErr, &termErr) + s.ErrorAs(opErr, &termErr) }, }, { outcome: "wait", action: "cancel", - checkWorkflowError: func(t *testing.T, wfErr error) { + checkWorkflowError: func(s *NexusWorkflowTestSuite, wfErr error) { // The Go SDK loses the NexusOperationError (as well as any other error if it wraps a CanceledError), // assertions done in workflow. var canceledErr *temporal.CanceledError - require.ErrorAs(t, wfErr, &canceledErr) + s.ErrorAs(wfErr, &canceledErr) }, }, { outcome: "timeout", - checkWorkflowError: func(t *testing.T, wfErr error) { + checkWorkflowError: func(s *NexusWorkflowTestSuite, wfErr error) { var opErr *temporal.NexusOperationError - require.ErrorAs(t, wfErr, &opErr) + s.ErrorAs(wfErr, &opErr) var timeoutErr *temporal.TimeoutError - require.ErrorAs(t, opErr, &timeoutErr) + s.ErrorAs(opErr, &timeoutErr) }, }, } - for _, tc := range cases { - s.T().Run(tc.outcome+"-"+tc.action, func(t *testing.T) { - capture := s.GetTestCluster().Host().CaptureMetricsHandler().StartCapture() - run, err := s.SdkClient().ExecuteWorkflow(ctx, client.StartWorkflowOptions{ - TaskQueue: taskQueue, - }, callerWF, tc.outcome, tc.action) - s.NoError(err) + testFn := func(s *NexusWorkflowTestSuite, tc testcase) { + env := newNexusTestEnv(s.T(), true, testcore.WithDedicatedCluster()) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*20) + defer cancel() + testCtx := ctx + taskQueue := testcore.RandomizeStr("caller_" + s.T().Name()) + endpointName := testcore.RandomizedNexusEndpoint(s.T().Name()) + handlerWorkflowID := testcore.RandomizeStr(s.T().Name()) - wfErr := run.Get(ctx, nil) - s.GetTestCluster().Host().CaptureMetricsHandler().StopCapture(capture) - tc.checkWorkflowError(t, wfErr) + _, err := env.SdkClient().OperatorService().CreateNexusEndpoint(ctx, &operatorservice.CreateNexusEndpointRequest{ + Spec: &nexuspb.EndpointSpec{ + Name: endpointName, + Target: &nexuspb.EndpointTarget{ + Variant: &nexuspb.EndpointTarget_Worker_{ + Worker: &nexuspb.EndpointTarget_Worker{ + Namespace: env.Namespace().String(), + TaskQueue: taskQueue, + }, + }, + }, + }, + }) + s.NoError(err) + + w := worker.New(env.SdkClient(), taskQueue, worker.Options{}) + svc := nexus.NewService("test") + + handlerWF := func(ctx workflow.Context, outcome string) (nexus.NoValue, error) { + switch outcome { + case "wait", "timeout": + // Wait for the workflow to be canceled. + return nil, workflow.Await(ctx, func() bool { return false }) + case "fail": + return nil, temporal.NewApplicationError("app error", "TestError", "details") + default: + } + return nil, fmt.Errorf("unexpected outcome: %s", outcome) + } - snap := capture.Snapshot() - require.GreaterOrEqual(t, len(snap["nexus_outbound_requests"]), 1) - require.Subset(t, snap["nexus_outbound_requests"][0].Tags, map[string]string{"namespace": s.Namespace().String(), "method": "StartOperation", "failure_source": "_unknown_", "outcome": "pending"}) + op := temporalnexus.NewWorkflowRunOperation("op", handlerWF, func(ctx context.Context, outcome string, soo nexus.StartOperationOptions) (client.StartWorkflowOptions, error) { + var workflowExecutionTimeout time.Duration + if outcome == "timeout" { + workflowExecutionTimeout = time.Second + } + return client.StartWorkflowOptions{ID: handlerWorkflowID, WorkflowExecutionTimeout: workflowExecutionTimeout}, nil }) + s.NoError(svc.Register(op)) + callerWF := func(ctx workflow.Context, outcome, action string) (nexus.NoValue, error) { + opCtx, cancel := workflow.WithCancel(ctx) + defer cancel() + c := workflow.NewNexusClient(endpointName, svc.Name) + fut := c.ExecuteOperation(opCtx, op, outcome, workflow.NexusOperationOptions{}) + var exec workflow.NexusOperationExecution + if err := fut.GetNexusOperationExecution().Get(ctx, &exec); err != nil { + return nil, err + } + switch action { + case "terminate": + // Lazy man's version of a local activity, don't try this at home. + workflow.SideEffect(ctx, func(ctx workflow.Context) any { + err := env.SdkClient().TerminateWorkflow(testCtx, handlerWorkflowID, "", "") + if err != nil { + panic(err) + } + return nil + }) + case "cancel": + cancel() + err := fut.Get(ctx, nil) + // The Go SDK unwraps CanceledErrors when an error is returned from the workflow, assert in-workflow. + var opErr *temporal.NexusOperationError + if !errors.As(err, &opErr) { + return nil, fmt.Errorf("expected NexusOperationError, got %w", err) + } + var canceledErr *temporal.CanceledError + if !errors.As(opErr, &canceledErr) { + return nil, fmt.Errorf("expected CanceledError, got %w", err) + } + default: + } + return nil, fut.Get(ctx, nil) + } + + w.RegisterNexusService(svc) + w.RegisterWorkflow(callerWF) + w.RegisterWorkflow(handlerWF) + s.NoError(w.Start()) + defer w.Stop() + + capture := env.GetTestCluster().Host().CaptureMetricsHandler().StartCapture() + run, err := env.SdkClient().ExecuteWorkflow(ctx, client.StartWorkflowOptions{ + TaskQueue: taskQueue, + }, callerWF, tc.outcome, tc.action) + s.NoError(err) + + wfErr := run.Get(ctx, nil) + env.GetTestCluster().Host().CaptureMetricsHandler().StopCapture(capture) + tc.checkWorkflowError(s, wfErr) + + snap := capture.Snapshot() + s.GreaterOrEqual(len(snap["nexus_outbound_requests"]), 1) + s.Subset(snap["nexus_outbound_requests"][0].Tags, map[string]string{"namespace": env.Namespace().String(), "method": "StartOperation", "failure_source": "_unknown_", "outcome": "pending"}) + } + + for _, tc := range cases { + s.Run(tc.outcome+"-"+tc.action, func(s *NexusWorkflowTestSuite) { + testFn(s, tc) + }) } } func (s *NexusWorkflowTestSuite) TestNexusCallbackAfterCallerComplete() { + env := newNexusTestEnv(s.T(), true, testcore.WithDedicatedCluster()) ctx, cancel := context.WithTimeout(context.Background(), time.Second*20) defer cancel() taskQueue := testcore.RandomizeStr("caller_" + s.T().Name()) endpointName := testcore.RandomizedNexusEndpoint(s.T().Name()) handlerWorkflowID := testcore.RandomizeStr(s.T().Name()) - _, err := s.SdkClient().OperatorService().CreateNexusEndpoint(ctx, &operatorservice.CreateNexusEndpointRequest{ + _, err := env.SdkClient().OperatorService().CreateNexusEndpoint(ctx, &operatorservice.CreateNexusEndpointRequest{ Spec: &nexuspb.EndpointSpec{ Name: endpointName, Target: &nexuspb.EndpointTarget{ Variant: &nexuspb.EndpointTarget_Worker_{ Worker: &nexuspb.EndpointTarget_Worker{ - Namespace: s.Namespace().String(), + Namespace: env.Namespace().String(), TaskQueue: taskQueue, }, }, @@ -2383,7 +2402,7 @@ func (s *NexusWorkflowTestSuite) TestNexusCallbackAfterCallerComplete() { s.NoError(err) w := worker.New( - s.SdkClient(), + env.SdkClient(), taskQueue, worker.Options{}, ) @@ -2416,18 +2435,18 @@ func (s *NexusWorkflowTestSuite) TestNexusCallbackAfterCallerComplete() { s.NoError(w.Start()) s.T().Cleanup(w.Stop) - run, err := s.SdkClient().ExecuteWorkflow(ctx, client.StartWorkflowOptions{ + run, err := env.SdkClient().ExecuteWorkflow(ctx, client.StartWorkflowOptions{ TaskQueue: taskQueue, }, callerWF) s.NoError(err) s.NoError(run.Get(ctx, nil)) - err = s.SdkClient().SignalWorkflow(ctx, handlerWorkflowID, "", "test-signal", nil) + err = env.SdkClient().SignalWorkflow(ctx, handlerWorkflowID, "", "test-signal", nil) s.NoError(err) s.EventuallyWithT(func(ct *assert.CollectT) { - resp, err := s.FrontendClient().DescribeWorkflowExecution(ctx, &workflowservice.DescribeWorkflowExecutionRequest{ - Namespace: s.Namespace().String(), + resp, err := env.FrontendClient().DescribeWorkflowExecution(ctx, &workflowservice.DescribeWorkflowExecutionRequest{ + Namespace: env.Namespace().String(), Execution: &commonpb.WorkflowExecution{ WorkflowId: handlerWorkflowID, }, @@ -2441,6 +2460,7 @@ func (s *NexusWorkflowTestSuite) TestNexusCallbackAfterCallerComplete() { } func (s *NexusWorkflowTestSuite) TestNexusOperationSyncNexusFailure() { + env := newNexusTestEnv(s.T(), true, testcore.WithDedicatedCluster()) ctx := testcore.NewContext() taskQueue := testcore.RandomizeStr(s.T().Name()) endpointName := testcore.RandomizedNexusEndpoint(s.T().Name()) @@ -2462,7 +2482,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationSyncNexusFailure() { listenAddr := nexustest.AllocListenAddress() nexustest.NewNexusServer(s.T(), listenAddr, h) - _, err := s.OperatorClient().CreateNexusEndpoint(ctx, &operatorservice.CreateNexusEndpointRequest{ + _, err := env.OperatorClient().CreateNexusEndpoint(ctx, &operatorservice.CreateNexusEndpointRequest{ Spec: &nexuspb.EndpointSpec{ Name: endpointName, Target: &nexuspb.EndpointTarget{ @@ -2477,7 +2497,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationSyncNexusFailure() { s.NoError(err) w := worker.New( - s.SdkClient(), + env.SdkClient(), taskQueue, worker.Options{}, ) @@ -2492,13 +2512,13 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationSyncNexusFailure() { s.NoError(w.Start()) s.T().Cleanup(w.Stop) - capture := s.GetTestCluster().Host().CaptureMetricsHandler().StartCapture() - run, err := s.SdkClient().ExecuteWorkflow(ctx, client.StartWorkflowOptions{ + capture := env.GetTestCluster().Host().CaptureMetricsHandler().StartCapture() + run, err := env.SdkClient().ExecuteWorkflow(ctx, client.StartWorkflowOptions{ TaskQueue: taskQueue, }, callerWF) s.NoError(err) wfErr := run.Get(ctx, nil) - s.GetTestCluster().Host().CaptureMetricsHandler().StopCapture(capture) + env.GetTestCluster().Host().CaptureMetricsHandler().StopCapture(capture) var handlerErr *nexus.HandlerError s.ErrorAs(wfErr, &handlerErr) @@ -2517,181 +2537,184 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationSyncNexusFailure() { snap := capture.Snapshot() s.Len(snap["nexus_outbound_requests"], 1) // Confirming that requests which do not go through our frontend are not tagged with `failure_source` - s.Subset(snap["nexus_outbound_requests"][0].Tags, map[string]string{"namespace": s.Namespace().String(), "method": "StartOperation", "failure_source": "_unknown_", "outcome": "handler-error:BAD_REQUEST"}) + s.Subset(snap["nexus_outbound_requests"][0].Tags, map[string]string{"namespace": env.Namespace().String(), "method": "StartOperation", "failure_source": "_unknown_", "outcome": "handler-error:BAD_REQUEST"}) } func (s *NexusWorkflowTestSuite) TestNexusAsyncOperationWithMultipleCallers() { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*20) - defer cancel() - callerTaskQueue := testcore.RandomizeStr("caller_" + s.T().Name()) - endpointName := testcore.RandomizedNexusEndpoint(s.T().Name()) - handlerWorkflowID := testcore.RandomizeStr(s.T().Name()) - // number of concurrent Nexus operation calls numCalls := 5 - - _, err := s.SdkClient().OperatorService().CreateNexusEndpoint(ctx, &operatorservice.CreateNexusEndpointRequest{ - Spec: &nexuspb.EndpointSpec{ - Name: endpointName, - Target: &nexuspb.EndpointTarget{ - Variant: &nexuspb.EndpointTarget_Worker_{ - Worker: &nexuspb.EndpointTarget_Worker{ - Namespace: s.Namespace().String(), - TaskQueue: callerTaskQueue, - }, - }, - }, - }, - }) - s.NoError(err) - - w := worker.New(s.SdkClient(), callerTaskQueue, worker.Options{}) - svc := nexus.NewService("test") handlerWf := func(ctx workflow.Context, input string) (string, error) { workflow.GetSignalChannel(ctx, "terminate").Receive(ctx, nil) return "hello " + input, nil } - - op := temporalnexus.NewWorkflowRunOperation( - "op", - handlerWf, - func(ctx context.Context, input string, opts nexus.StartOperationOptions) (client.StartWorkflowOptions, error) { - var conflictPolicy enumspb.WorkflowIdConflictPolicy - if input == "conflict-policy-use-existing" { - conflictPolicy = enumspb.WORKFLOW_ID_CONFLICT_POLICY_USE_EXISTING - } - return client.StartWorkflowOptions{ - ID: handlerWorkflowID, - WorkflowIDConflictPolicy: conflictPolicy, - }, nil - }, - ) - svc.MustRegister(op) + handlerWorkflowID := testcore.RandomizeStr(s.T().Name()) + callerTaskQueue := testcore.RandomizeStr("caller_" + s.T().Name()) type CallerWfOutput struct { CntOk int CntErr int } + type CallerWfFn = func(ctx workflow.Context, input string) (CallerWfOutput, error) - callerWf := func(ctx workflow.Context, input string) (CallerWfOutput, error) { - output := CallerWfOutput{} - var retError error + buildNexusEnvFn := func(ctx context.Context, s *NexusWorkflowTestSuite) (*NexusTestEnv, CallerWfFn) { + env := newNexusTestEnv(s.T(), true, testcore.WithDedicatedCluster()) + endpointName := testcore.RandomizedNexusEndpoint(s.T().Name()) - c := workflow.NewNexusClient(endpointName, svc.Name) + _, err := env.SdkClient().OperatorService().CreateNexusEndpoint(ctx, &operatorservice.CreateNexusEndpointRequest{ + Spec: &nexuspb.EndpointSpec{ + Name: endpointName, + Target: &nexuspb.EndpointTarget{ + Variant: &nexuspb.EndpointTarget_Worker_{ + Worker: &nexuspb.EndpointTarget_Worker{ + Namespace: env.Namespace().String(), + TaskQueue: callerTaskQueue, + }, + }, + }, + }, + }) + s.NoError(err) - nexusFutures := []workflow.NexusOperationFuture{} - for range numCalls { - fut := c.ExecuteOperation(ctx, op, input, workflow.NexusOperationOptions{}) - nexusFutures = append(nexusFutures, fut) - } + w := worker.New(env.SdkClient(), callerTaskQueue, worker.Options{}) + svc := nexus.NewService("test") + op := temporalnexus.NewWorkflowRunOperation( + "op", + handlerWf, + func(ctx context.Context, input string, opts nexus.StartOperationOptions) (client.StartWorkflowOptions, error) { + var conflictPolicy enumspb.WorkflowIdConflictPolicy + if input == "conflict-policy-use-existing" { + conflictPolicy = enumspb.WORKFLOW_ID_CONFLICT_POLICY_USE_EXISTING + } + return client.StartWorkflowOptions{ + ID: handlerWorkflowID, + WorkflowIDConflictPolicy: conflictPolicy, + }, nil + }, + ) + svc.MustRegister(op) - nexusOpStartedFutures := []workflow.NexusOperationFuture{} - for _, fut := range nexusFutures { - var exec workflow.NexusOperationExecution - err := fut.GetNexusOperationExecution().Get(ctx, &exec) - if err == nil { - output.CntOk++ - nexusOpStartedFutures = append(nexusOpStartedFutures, fut) - continue + callerWf := func(ctx workflow.Context, input string) (CallerWfOutput, error) { + output := CallerWfOutput{} + var retError error + + c := workflow.NewNexusClient(endpointName, svc.Name) + + nexusFutures := []workflow.NexusOperationFuture{} + for range numCalls { + fut := c.ExecuteOperation(ctx, op, input, workflow.NexusOperationOptions{}) + nexusFutures = append(nexusFutures, fut) } - output.CntErr++ - var handlerErr *nexus.HandlerError - var appErr *temporal.ApplicationError - if !errors.As(err, &handlerErr) { - retError = err - } else if !errors.As(handlerErr, &appErr) { - retError = err - } else if appErr.Type() != "WorkflowExecutionAlreadyStarted" { - retError = err + + nexusOpStartedFutures := []workflow.NexusOperationFuture{} + for _, fut := range nexusFutures { + var exec workflow.NexusOperationExecution + err := fut.GetNexusOperationExecution().Get(ctx, &exec) + if err == nil { + output.CntOk++ + nexusOpStartedFutures = append(nexusOpStartedFutures, fut) + continue + } + output.CntErr++ + var handlerErr *nexus.HandlerError + var appErr *temporal.ApplicationError + if !errors.As(err, &handlerErr) || !errors.As(handlerErr, &appErr) || appErr.Type() != "WorkflowExecutionAlreadyStarted" { + retError = err + } } - } - if output.CntOk > 0 { - // signal handler workflow so it will complete - err = workflow.SignalExternalWorkflow(ctx, handlerWorkflowID, "", "terminate", nil).Get(ctx, nil) - if err != nil { - return output, err + if output.CntOk > 0 { + // signal handler workflow so it will complete + err = workflow.SignalExternalWorkflow(ctx, handlerWorkflowID, "", "terminate", nil).Get(ctx, nil) + if err != nil { + return output, err + } } - } - for _, fut := range nexusOpStartedFutures { - var res string - err := fut.Get(ctx, &res) - if err != nil { - retError = err - } else if res != "hello "+input { - retError = fmt.Errorf("unexpected result from handler workflow: %q", res) + for _, fut := range nexusOpStartedFutures { + var res string + err := fut.Get(ctx, &res) + if err != nil { + retError = err + } else if res != "hello "+input { + retError = fmt.Errorf("unexpected result from handler workflow: %q", res) + } } + + return output, retError } - return output, retError - } + w.RegisterNexusService(svc) + w.RegisterWorkflow(handlerWf) + w.RegisterWorkflowWithOptions(callerWf, workflow.RegisterOptions{Name: "caller-wf"}) + s.NoError(w.Start()) - w.RegisterNexusService(svc) - w.RegisterWorkflow(handlerWf) - w.RegisterWorkflowWithOptions(callerWf, workflow.RegisterOptions{Name: "caller-wf"}) - s.NoError(w.Start()) - defer w.Stop() + // s.T().Cleanup(...) runs after the s.T()'s test finishes, not after this function returns + s.T().Cleanup(func() { w.Stop() }) + return env, callerWf + } testCases := []struct { input string - checkOutput func(t *testing.T, res CallerWfOutput, err error) + checkOutput func(s *NexusWorkflowTestSuite, env *NexusTestEnv, res CallerWfOutput, err error) }{ { input: "conflict-policy-fail", - checkOutput: func(t *testing.T, res CallerWfOutput, err error) { - require.NoError(t, err) - require.Equal(t, 1, res.CntOk) - require.Equal(t, numCalls-1, res.CntErr) + checkOutput: func(s *NexusWorkflowTestSuite, env *NexusTestEnv, res CallerWfOutput, err error) { + s.NoError(err) + s.Equal(1, res.CntOk) + s.Equal(numCalls-1, res.CntErr) // check the handler workflow has the request ID infos map correct - descResp, err := s.SdkClient().DescribeWorkflowExecution(context.Background(), handlerWorkflowID, "") - require.NoError(t, err) + descResp, err := env.SdkClient().DescribeWorkflowExecution(context.Background(), handlerWorkflowID, "") + s.NoError(err) requestIDInfos := descResp.GetWorkflowExtendedInfo().GetRequestIdInfos() - require.NotNil(t, requestIDInfos) - require.Len(t, requestIDInfos, 1) + s.NotNil(requestIDInfos) + s.Len(requestIDInfos, 1) for _, info := range requestIDInfos { - require.False(t, info.Buffered) - require.GreaterOrEqual(t, info.EventId, common.FirstEventID) - require.Equal(t, enumspb.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED, info.EventType) + s.False(info.Buffered) + s.GreaterOrEqual(info.EventId, common.FirstEventID) + s.Equal(enumspb.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED, info.EventType) } }, }, { input: "conflict-policy-use-existing", - checkOutput: func(t *testing.T, res CallerWfOutput, err error) { - require.NoError(t, err) - require.Equal(t, numCalls, res.CntOk) - require.Equal(t, 0, res.CntErr) + checkOutput: func(s *NexusWorkflowTestSuite, env *NexusTestEnv, res CallerWfOutput, err error) { + s.NoError(err) + s.Equal(numCalls, res.CntOk) + s.Equal(0, res.CntErr) // check the handler workflow has the request ID infos map correct - descResp, err := s.SdkClient().DescribeWorkflowExecution(context.Background(), handlerWorkflowID, "") - require.NoError(t, err) + descResp, err := env.SdkClient().DescribeWorkflowExecution(context.Background(), handlerWorkflowID, "") + s.NoError(err) requestIDInfos := descResp.GetWorkflowExtendedInfo().GetRequestIdInfos() - require.NotNil(t, requestIDInfos) + s.NotNil(requestIDInfos) cntStarted := 0 cntAttached := 0 for _, info := range requestIDInfos { - require.False(t, info.Buffered) - require.GreaterOrEqual(t, info.EventId, common.FirstEventID) + s.False(info.Buffered) + s.GreaterOrEqual(info.EventId, common.FirstEventID) switch info.EventType { case enumspb.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED: cntStarted++ case enumspb.EVENT_TYPE_WORKFLOW_EXECUTION_OPTIONS_UPDATED: cntAttached++ default: - require.Fail(t, "Unexpected event type in request ID info") + s.Fail("Unexpected event type in request ID info") } } - require.Equal(t, 1, cntStarted) - require.Equal(t, numCalls-1, cntAttached) + s.Equal(1, cntStarted) + s.Equal(numCalls-1, cntAttached) }, }, } for _, tc := range testCases { - s.Run(tc.input, func() { - run, err := s.SdkClient().ExecuteWorkflow( + s.Run(tc.input, func(s *NexusWorkflowTestSuite) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*20) + defer cancel() + env, callerWf := buildNexusEnvFn(ctx, s) + run, err := env.SdkClient().ExecuteWorkflow( ctx, client.StartWorkflowOptions{ TaskQueue: callerTaskQueue, @@ -2702,23 +2725,24 @@ func (s *NexusWorkflowTestSuite) TestNexusAsyncOperationWithMultipleCallers() { s.NoError(err) var res CallerWfOutput err = run.Get(ctx, &res) - tc.checkOutput(s.T(), res, err) + tc.checkOutput(s, env, res, err) }) } } func (s *NexusWorkflowTestSuite) TestNexusOperationScheduleToCloseTimeout() { + env := newNexusTestEnv(s.T(), true, testcore.WithDedicatedCluster()) ctx := testcore.NewContext() taskQueue := testcore.RandomizeStr(s.T().Name()) endpointName := testcore.RandomizedNexusEndpoint(s.T().Name()) - _, err := s.OperatorClient().CreateNexusEndpoint(ctx, &operatorservice.CreateNexusEndpointRequest{ + _, err := env.OperatorClient().CreateNexusEndpoint(ctx, &operatorservice.CreateNexusEndpointRequest{ Spec: &nexuspb.EndpointSpec{ Name: endpointName, Target: &nexuspb.EndpointTarget{ Variant: &nexuspb.EndpointTarget_Worker_{ Worker: &nexuspb.EndpointTarget_Worker{ - Namespace: s.Namespace().String(), + Namespace: env.Namespace().String(), TaskQueue: "unreachable-for-test", }, }, @@ -2727,14 +2751,14 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationScheduleToCloseTimeout() { }) s.NoError(err) - run, err := s.SdkClient().ExecuteWorkflow(ctx, client.StartWorkflowOptions{ + run, err := env.SdkClient().ExecuteWorkflow(ctx, client.StartWorkflowOptions{ TaskQueue: taskQueue, }, "workflow") s.NoError(err) // Schedule the operation with a short schedule-to-close timeout - pollResp, err := s.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ - Namespace: s.Namespace().String(), + pollResp, err := env.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ + Namespace: env.Namespace().String(), TaskQueue: &taskqueuepb.TaskQueue{ Name: taskQueue, Kind: enumspb.TASK_QUEUE_KIND_NORMAL, @@ -2742,7 +2766,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationScheduleToCloseTimeout() { Identity: "test", }) s.NoError(err) - _, err = s.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ + _, err = env.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ Identity: "test", TaskToken: pollResp.TaskToken, Commands: []*commandpb.Command{ @@ -2762,14 +2786,14 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationScheduleToCloseTimeout() { }) s.NoError(err) - descResp, err := s.SdkClient().DescribeWorkflowExecution(ctx, run.GetID(), run.GetRunID()) + descResp, err := env.SdkClient().DescribeWorkflowExecution(ctx, run.GetID(), run.GetRunID()) s.NoError(err) s.Len(descResp.PendingNexusOperations, 1) s.Equal(2*time.Second, descResp.PendingNexusOperations[0].ScheduleToCloseTimeout.AsDuration()) // Now wait for the timeout event - pollResp, err = s.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ - Namespace: s.Namespace().String(), + pollResp, err = env.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ + Namespace: env.Namespace().String(), TaskQueue: &taskqueuepb.TaskQueue{ Name: taskQueue, Kind: enumspb.TASK_QUEUE_KIND_NORMAL, @@ -2788,7 +2812,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationScheduleToCloseTimeout() { timedOutEvent.GetNexusOperationTimedOutEventAttributes().GetFailure().GetCause().GetTimeoutFailureInfo().GetTimeoutType()) // Complete the workflow - _, err = s.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ + _, err = env.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ Identity: "test", TaskToken: pollResp.TaskToken, Commands: []*commandpb.Command{ @@ -2805,17 +2829,18 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationScheduleToCloseTimeout() { } func (s *NexusWorkflowTestSuite) TestNexusOperationScheduleToStartTimeout() { + env := newNexusTestEnv(s.T(), true, testcore.WithDedicatedCluster()) ctx := testcore.NewContext() taskQueue := testcore.RandomizeStr(s.T().Name()) endpointName := testcore.RandomizedNexusEndpoint(s.T().Name()) - _, err := s.OperatorClient().CreateNexusEndpoint(ctx, &operatorservice.CreateNexusEndpointRequest{ + _, err := env.OperatorClient().CreateNexusEndpoint(ctx, &operatorservice.CreateNexusEndpointRequest{ Spec: &nexuspb.EndpointSpec{ Name: endpointName, Target: &nexuspb.EndpointTarget{ Variant: &nexuspb.EndpointTarget_Worker_{ Worker: &nexuspb.EndpointTarget_Worker{ - Namespace: s.Namespace().String(), + Namespace: env.Namespace().String(), TaskQueue: "unreachable-for-test", }, }, @@ -2824,14 +2849,14 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationScheduleToStartTimeout() { }) s.NoError(err) - run, err := s.SdkClient().ExecuteWorkflow(ctx, client.StartWorkflowOptions{ + run, err := env.SdkClient().ExecuteWorkflow(ctx, client.StartWorkflowOptions{ TaskQueue: taskQueue, }, "workflow") s.NoError(err) // Schedule the operation with a short schedule-to-close timeout - pollResp, err := s.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ - Namespace: s.Namespace().String(), + pollResp, err := env.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ + Namespace: env.Namespace().String(), TaskQueue: &taskqueuepb.TaskQueue{ Name: taskQueue, Kind: enumspb.TASK_QUEUE_KIND_NORMAL, @@ -2839,7 +2864,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationScheduleToStartTimeout() { Identity: "test", }) s.NoError(err) - _, err = s.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ + _, err = env.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ Identity: "test", TaskToken: pollResp.TaskToken, Commands: []*commandpb.Command{ @@ -2859,14 +2884,14 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationScheduleToStartTimeout() { }) s.NoError(err) - descResp, err := s.SdkClient().DescribeWorkflowExecution(ctx, run.GetID(), run.GetRunID()) + descResp, err := env.SdkClient().DescribeWorkflowExecution(ctx, run.GetID(), run.GetRunID()) s.NoError(err) s.Len(descResp.PendingNexusOperations, 1) s.Equal(2*time.Second, descResp.PendingNexusOperations[0].ScheduleToStartTimeout.AsDuration()) // Now wait for the timeout event - pollResp, err = s.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ - Namespace: s.Namespace().String(), + pollResp, err = env.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ + Namespace: env.Namespace().String(), TaskQueue: &taskqueuepb.TaskQueue{ Name: taskQueue, Kind: enumspb.TASK_QUEUE_KIND_NORMAL, @@ -2885,7 +2910,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationScheduleToStartTimeout() { timedOutEvent.GetNexusOperationTimedOutEventAttributes().GetFailure().GetCause().GetTimeoutFailureInfo().GetTimeoutType()) // Complete the workflow - _, err = s.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ + _, err = env.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ Identity: "test", TaskToken: pollResp.TaskToken, Commands: []*commandpb.Command{ @@ -2902,6 +2927,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationScheduleToStartTimeout() { } func (s *NexusWorkflowTestSuite) TestNexusOperationStartToCloseTimeout() { + env := newNexusTestEnv(s.T(), true, testcore.WithDedicatedCluster()) ctx := testcore.NewContext() taskQueue := testcore.RandomizeStr(s.T().Name()) endpointName := testcore.RandomizedNexusEndpoint(s.T().Name()) @@ -2916,7 +2942,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationStartToCloseTimeout() { listenAddr := nexustest.AllocListenAddress() nexustest.NewNexusServer(s.T(), listenAddr, h) - _, err := s.OperatorClient().CreateNexusEndpoint(ctx, &operatorservice.CreateNexusEndpointRequest{ + _, err := env.OperatorClient().CreateNexusEndpoint(ctx, &operatorservice.CreateNexusEndpointRequest{ Spec: &nexuspb.EndpointSpec{ Name: endpointName, Target: &nexuspb.EndpointTarget{ @@ -2930,14 +2956,14 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationStartToCloseTimeout() { }) s.NoError(err) - run, err := s.SdkClient().ExecuteWorkflow(ctx, client.StartWorkflowOptions{ + run, err := env.SdkClient().ExecuteWorkflow(ctx, client.StartWorkflowOptions{ TaskQueue: taskQueue, }, "workflow") s.NoError(err) // Schedule the operation with a short start-to-close timeout - pollResp, err := s.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ - Namespace: s.Namespace().String(), + pollResp, err := env.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ + Namespace: env.Namespace().String(), TaskQueue: &taskqueuepb.TaskQueue{ Name: taskQueue, Kind: enumspb.TASK_QUEUE_KIND_NORMAL, @@ -2945,7 +2971,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationStartToCloseTimeout() { Identity: "test", }) s.NoError(err) - _, err = s.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ + _, err = env.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ Identity: "test", TaskToken: pollResp.TaskToken, Commands: []*commandpb.Command{ @@ -2965,14 +2991,14 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationStartToCloseTimeout() { }) s.NoError(err) - descResp, err := s.SdkClient().DescribeWorkflowExecution(ctx, run.GetID(), run.GetRunID()) + descResp, err := env.SdkClient().DescribeWorkflowExecution(ctx, run.GetID(), run.GetRunID()) s.NoError(err) s.Len(descResp.PendingNexusOperations, 1) s.Equal(2*time.Second, descResp.PendingNexusOperations[0].StartToCloseTimeout.AsDuration()) // Wait for the started event first - pollResp, err = s.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ - Namespace: s.Namespace().String(), + pollResp, err = env.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ + Namespace: env.Namespace().String(), TaskQueue: &taskqueuepb.TaskQueue{ Name: taskQueue, Kind: enumspb.TASK_QUEUE_KIND_NORMAL, @@ -2988,15 +3014,15 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationStartToCloseTimeout() { s.Positive(startedEventIdx) // Respond to acknowledge the started event - _, err = s.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ + _, err = env.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ Identity: "test", TaskToken: pollResp.TaskToken, }) s.NoError(err) // Now wait for the timeout event - pollResp, err = s.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ - Namespace: s.Namespace().String(), + pollResp, err = env.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ + Namespace: env.Namespace().String(), TaskQueue: &taskqueuepb.TaskQueue{ Name: taskQueue, Kind: enumspb.TASK_QUEUE_KIND_NORMAL, @@ -3016,7 +3042,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationStartToCloseTimeout() { s.Contains(timedOutEvent.GetNexusOperationTimedOutEventAttributes().GetFailure().GetCause().GetMessage(), "operation timed out") // Complete the workflow - _, err = s.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ + _, err = env.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ Identity: "test", TaskToken: pollResp.TaskToken, Commands: []*commandpb.Command{ @@ -3061,11 +3087,12 @@ func (s *NexusWorkflowTestSuite) generateValidCallbackToken(namespaceID, workflo func (s *NexusWorkflowTestSuite) sendNexusCompletionRequest( ctx context.Context, + env *NexusTestEnv, url string, completion nexusrpc.CompleteOperationOptions, ) (map[string][]*metricstest.CapturedRecording, error) { - capture := s.GetTestCluster().Host().CaptureMetricsHandler().StartCapture() - defer s.GetTestCluster().Host().CaptureMetricsHandler().StopCapture(capture) + capture := env.GetTestCluster().Host().CaptureMetricsHandler().StartCapture() + defer env.GetTestCluster().Host().CaptureMetricsHandler().StopCapture(capture) c := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{ Serializer: commonnexus.PayloadSerializer, @@ -3077,16 +3104,17 @@ func (s *NexusWorkflowTestSuite) sendNexusCompletionRequest( // NOTE: This test cannot use the SDK workflow package because there is a restriction that prevents setting the // __temporal_system endpoint. func (s *NexusWorkflowTestSuite) TestNexusOperationSystemEndpoint() { + env := newNexusTestEnv(s.T(), true, testcore.WithDedicatedCluster()) ctx := testcore.NewContext() taskQueue := testcore.RandomizeStr(s.T().Name()) - run, err := s.SdkClient().ExecuteWorkflow(ctx, client.StartWorkflowOptions{ + run, err := env.SdkClient().ExecuteWorkflow(ctx, client.StartWorkflowOptions{ TaskQueue: taskQueue, }, "workflow") s.NoError(err) - pollResp, err := s.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ - Namespace: s.Namespace().String(), + pollResp, err := env.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ + Namespace: env.Namespace().String(), TaskQueue: &taskqueuepb.TaskQueue{ Name: taskQueue, Kind: enumspb.TASK_QUEUE_KIND_NORMAL, @@ -3094,7 +3122,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationSystemEndpoint() { Identity: "test", }) s.NoError(err) - _, err = s.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ + _, err = env.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ Identity: "test", TaskToken: pollResp.TaskToken, Commands: []*commandpb.Command{ @@ -3114,8 +3142,8 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationSystemEndpoint() { s.NoError(err) // Poll for the completion - pollResp, err = s.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ - Namespace: s.Namespace().String(), + pollResp, err = env.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ + Namespace: env.Namespace().String(), TaskQueue: &taskqueuepb.TaskQueue{ Name: taskQueue, Kind: enumspb.TASK_QUEUE_KIND_NORMAL, @@ -3136,7 +3164,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationSystemEndpoint() { s.NotNil(result) // Complete the workflow - _, err = s.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ + _, err = env.FrontendClient().RespondWorkflowTaskCompleted(ctx, &workflowservice.RespondWorkflowTaskCompletedRequest{ Identity: "test", TaskToken: pollResp.TaskToken, Commands: []*commandpb.Command{ diff --git a/tests/testcore/test_env.go b/tests/testcore/test_env.go index 50d191ece2f..7776d497954 100644 --- a/tests/testcore/test_env.go +++ b/tests/testcore/test_env.go @@ -25,6 +25,7 @@ import ( "go.temporal.io/server/common/testing/taskpoller" "go.temporal.io/server/common/testing/testhooks" "go.temporal.io/server/common/testing/testvars" + "go.temporal.io/server/components/nexusoperations" "google.golang.org/grpc" ) @@ -156,7 +157,8 @@ func NewEnv(t *testing.T, opts ...TestOption) *TestEnv { cluster := base.GetTestCluster() // Create a dedicated namespace for the test to help with test isolation. - ns := namespace.Name(RandomizeStr(t.Name())) + baseName := strings.ReplaceAll(t.Name(), "/", "-") + ns := namespace.Name(RandomizeStr(baseName)) nsID, err := base.RegisterNamespace( ns, 1, // 1 day retention @@ -182,6 +184,12 @@ func NewEnv(t *testing.T, opts ...TestOption) *TestEnv { sdkWorkerTQ: RandomizeStr("tq-" + t.Name()), } + // Set Nexus callback URL now that we have the cluster's HTTP address. Note that we set + // a default for the global config here so callers that rely on this can still use a shared cluster. + env.FunctionalTestBase.OverrideDynamicConfig( + nexusoperations.CallbackURLTemplate, + "http://"+env.HttpAPIAddress()+"/namespaces/{{.NamespaceName}}/nexus/callback") + // For shared clusters, apply all dynamic config settings as overrides. if !options.dedicatedCluster && len(options.dynamicConfigSettings) > 0 { for _, override := range options.dynamicConfigSettings {