Skip to content

Commit 6cad7df

Browse files
committed
Update nexus_test_base.go to use a testcore.TestEnv base
1 parent f87611a commit 6cad7df

2 files changed

Lines changed: 56 additions & 29 deletions

File tree

tests/nexus_test_base.go

Lines changed: 46 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@ package tests
33
import (
44
"context"
55
"errors"
6+
"testing"
67

78
"github.com/google/uuid"
89
"github.com/nexus-rpc/sdk-go/nexus"
10+
"github.com/stretchr/testify/require"
911
commonpb "go.temporal.io/api/common/v1"
1012
enumspb "go.temporal.io/api/enums/v1"
1113
nexuspb "go.temporal.io/api/nexus/v1"
@@ -15,29 +17,44 @@ import (
1517
"go.temporal.io/api/workflowservice/v1"
1618
cnexus "go.temporal.io/server/common/nexus"
1719
"go.temporal.io/server/common/nexus/nexusrpc"
20+
"go.temporal.io/server/components/nexusoperations"
1821
"go.temporal.io/server/tests/testcore"
1922
)
2023

21-
type NexusTestBaseSuite struct {
22-
testcore.FunctionalTestBase
24+
type NexusTestEnv struct {
25+
*testcore.TestEnv
2326
useTemporalFailures bool
2427
}
2528

26-
func (s *NexusTestBaseSuite) createNexusEndpoint(name string, taskQueue string) *nexuspb.Endpoint {
27-
resp, err := s.OperatorClient().CreateNexusEndpoint(testcore.NewContext(), &operatorservice.CreateNexusEndpointRequest{
29+
func newNexusTestEnv(t *testing.T, useTemporalFailures bool, opts ...testcore.TestOption) *NexusTestEnv {
30+
env := &NexusTestEnv{
31+
TestEnv: testcore.NewEnv(t, opts...),
32+
useTemporalFailures: useTemporalFailures,
33+
}
34+
if testcore.IsDedicatedCluster(opts) {
35+
env.OverrideDynamicConfig(
36+
nexusoperations.CallbackURLTemplate,
37+
"http://"+env.HttpAPIAddress()+"/nexus/callback",
38+
)
39+
}
40+
return env
41+
}
42+
43+
func (env *NexusTestEnv) createNexusEndpoint(name string, taskQueue string, t *testing.T) *nexuspb.Endpoint {
44+
resp, err := env.OperatorClient().CreateNexusEndpoint(testcore.NewContext(), &operatorservice.CreateNexusEndpointRequest{
2845
Spec: &nexuspb.EndpointSpec{
2946
Name: name,
3047
Target: &nexuspb.EndpointTarget{
3148
Variant: &nexuspb.EndpointTarget_Worker_{
3249
Worker: &nexuspb.EndpointTarget_Worker{
33-
Namespace: s.Namespace().String(),
50+
Namespace: env.Namespace().String(),
3451
TaskQueue: taskQueue,
3552
},
3653
},
3754
},
3855
},
3956
})
40-
s.NoError(err)
57+
require.NoError(t, err)
4158
return resp.Endpoint
4259
}
4360

@@ -56,28 +73,28 @@ type nexusTaskResponse struct {
5673

5774
type nexusTaskHandler func(res *workflowservice.PollNexusTaskQueueResponse) (*nexusTaskResponse, error)
5875

59-
func (s *NexusTestBaseSuite) nexusTaskPoller(ctx context.Context, taskQueue string, handler nexusTaskHandler) <-chan error {
60-
return s.versionedNexusTaskPoller(ctx, taskQueue, "", handler)
76+
func (env *NexusTestEnv) nexusTaskPoller(ctx context.Context, taskQueue string, handler nexusTaskHandler) <-chan error {
77+
return env.versionedNexusTaskPoller(ctx, taskQueue, "", handler)
6178
}
6279

63-
func (s *NexusTestBaseSuite) versionedNexusTaskPoller(ctx context.Context, taskQueue, buildID string, handler nexusTaskHandler) <-chan error {
80+
func (env *NexusTestEnv) versionedNexusTaskPoller(ctx context.Context, taskQueue, buildID string, handler nexusTaskHandler) <-chan error {
6481
errCh := make(chan error, 1)
6582
go func() {
66-
errCh <- s.versionedNexusTaskPollerDo(ctx, taskQueue, buildID, handler)
83+
errCh <- env.versionedNexusTaskPollerDo(ctx, taskQueue, buildID, handler)
6784
}()
6885
return errCh
6986
}
7087

71-
func (s *NexusTestBaseSuite) versionedNexusTaskPollerDo(ctx context.Context, taskQueue, buildID string, handler nexusTaskHandler) error {
88+
func (env *NexusTestEnv) versionedNexusTaskPollerDo(ctx context.Context, taskQueue, buildID string, handler nexusTaskHandler) error {
7289
var vc *commonpb.WorkerVersionCapabilities
7390
if buildID != "" {
7491
vc = &commonpb.WorkerVersionCapabilities{
7592
BuildId: buildID,
7693
UseVersioning: true,
7794
}
7895
}
79-
res, err := s.GetTestCluster().FrontendClient().PollNexusTaskQueue(ctx, &workflowservice.PollNexusTaskQueueRequest{
80-
Namespace: s.Namespace().String(),
96+
res, err := env.GetTestCluster().FrontendClient().PollNexusTaskQueue(ctx, &workflowservice.PollNexusTaskQueueRequest{
97+
Namespace: env.Namespace().String(),
8198
Identity: uuid.NewString(),
8299
TaskQueue: &taskqueuepb.TaskQueue{
83100
Name: taskQueue,
@@ -103,9 +120,9 @@ func (s *NexusTestBaseSuite) versionedNexusTaskPollerDo(ctx context.Context, tas
103120
var opErr *nexus.OperationError
104121
var he *nexus.HandlerError
105122
if errors.As(handlerErr, &opErr) {
106-
return s.respondNexusTaskCompletedWithOperationError(ctx, res.TaskToken, opErr)
123+
return env.respondNexusTaskCompletedWithOperationError(ctx, res.TaskToken, opErr)
107124
} else if errors.As(handlerErr, &he) {
108-
return s.respondNexusTaskFailed(ctx, res.TaskToken, he)
125+
return env.respondNexusTaskFailed(ctx, res.TaskToken, he)
109126
}
110127
return handlerErr
111128
}
@@ -163,8 +180,8 @@ func (s *NexusTestBaseSuite) versionedNexusTaskPollerDo(ctx context.Context, tas
163180
panic("unreachable") // nolint:revive // all implementations of HandlerStartOperationResult must be covered here, so this should be unreachable.
164181
}
165182
}
166-
_, err = s.GetTestCluster().FrontendClient().RespondNexusTaskCompleted(ctx, &workflowservice.RespondNexusTaskCompletedRequest{
167-
Namespace: s.Namespace().String(),
183+
_, err = env.GetTestCluster().FrontendClient().RespondNexusTaskCompleted(ctx, &workflowservice.RespondNexusTaskCompletedRequest{
184+
Namespace: env.Namespace().String(),
168185
Identity: uuid.NewString(),
169186
TaskToken: res.TaskToken,
170187
Response: response,
@@ -175,8 +192,8 @@ func (s *NexusTestBaseSuite) versionedNexusTaskPollerDo(ctx context.Context, tas
175192
return nil
176193
}
177194

178-
func (s *NexusTestBaseSuite) respondNexusTaskFailed(ctx context.Context, taskToken []byte, he *nexus.HandlerError) error {
179-
if s.useTemporalFailures {
195+
func (env *NexusTestEnv) respondNexusTaskFailed(ctx context.Context, taskToken []byte, he *nexus.HandlerError) error {
196+
if env.useTemporalFailures {
180197
nexusFailure, err := nexusrpc.DefaultFailureConverter().ErrorToFailure(he)
181198
if err != nil {
182199
return err
@@ -185,8 +202,8 @@ func (s *NexusTestBaseSuite) respondNexusTaskFailed(ctx context.Context, taskTok
185202
if err != nil {
186203
return err
187204
}
188-
_, err = s.GetTestCluster().FrontendClient().RespondNexusTaskFailed(ctx, &workflowservice.RespondNexusTaskFailedRequest{
189-
Namespace: s.Namespace().String(),
205+
_, err = env.GetTestCluster().FrontendClient().RespondNexusTaskFailed(ctx, &workflowservice.RespondNexusTaskFailedRequest{
206+
Namespace: env.Namespace().String(),
190207
Identity: uuid.NewString(),
191208
TaskToken: taskToken,
192209
Failure: temporalFailure,
@@ -219,8 +236,8 @@ func (s *NexusTestBaseSuite) respondNexusTaskFailed(ctx context.Context, taskTok
219236
protoError.RetryBehavior = enumspb.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_NON_RETRYABLE
220237
default:
221238
}
222-
_, err := s.GetTestCluster().FrontendClient().RespondNexusTaskFailed(ctx, &workflowservice.RespondNexusTaskFailedRequest{
223-
Namespace: s.Namespace().String(),
239+
_, err := env.GetTestCluster().FrontendClient().RespondNexusTaskFailed(ctx, &workflowservice.RespondNexusTaskFailedRequest{
240+
Namespace: env.Namespace().String(),
224241
Identity: uuid.NewString(),
225242
TaskToken: taskToken,
226243
Error: protoError,
@@ -231,8 +248,8 @@ func (s *NexusTestBaseSuite) respondNexusTaskFailed(ctx context.Context, taskTok
231248
return nil
232249
}
233250

234-
func (s *NexusTestBaseSuite) respondNexusTaskCompletedWithOperationError(ctx context.Context, taskToken []byte, opErr *nexus.OperationError) error {
235-
if s.useTemporalFailures {
251+
func (env *NexusTestEnv) respondNexusTaskCompletedWithOperationError(ctx context.Context, taskToken []byte, opErr *nexus.OperationError) error {
252+
if env.useTemporalFailures {
236253
nexusFailure, err := nexusrpc.DefaultFailureConverter().ErrorToFailure(opErr)
237254
if err != nil {
238255
return err
@@ -250,8 +267,8 @@ func (s *NexusTestBaseSuite) respondNexusTaskCompletedWithOperationError(ctx con
250267
},
251268
},
252269
}
253-
_, err = s.GetTestCluster().FrontendClient().RespondNexusTaskCompleted(ctx, &workflowservice.RespondNexusTaskCompletedRequest{
254-
Namespace: s.Namespace().String(),
270+
_, err = env.GetTestCluster().FrontendClient().RespondNexusTaskCompleted(ctx, &workflowservice.RespondNexusTaskCompletedRequest{
271+
Namespace: env.Namespace().String(),
255272
Identity: uuid.NewString(),
256273
TaskToken: taskToken,
257274
Response: response,
@@ -284,8 +301,8 @@ func (s *NexusTestBaseSuite) respondNexusTaskCompletedWithOperationError(ctx con
284301
},
285302
},
286303
}
287-
_, err := s.GetTestCluster().FrontendClient().RespondNexusTaskCompleted(ctx, &workflowservice.RespondNexusTaskCompletedRequest{
288-
Namespace: s.Namespace().String(),
304+
_, err := env.GetTestCluster().FrontendClient().RespondNexusTaskCompleted(ctx, &workflowservice.RespondNexusTaskCompletedRequest{
305+
Namespace: env.Namespace().String(),
289306
Identity: uuid.NewString(),
290307
TaskToken: taskToken,
291308
Response: response,

tests/testcore/test_env.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,3 +358,13 @@ func checkTestShard(t *testing.T) {
358358
}
359359
t.Logf("Running %s in test shard %d/%d", t.Name(), index+1, total)
360360
}
361+
362+
// IsDedicatedCluster allows caller to inspect whether opts contains
363+
// testcore.WithDedicatedCluster().
364+
func IsDedicatedCluster(opts []TestOption) bool {
365+
var o testOptions
366+
for _, opt := range opts {
367+
opt(&o)
368+
}
369+
return o.dedicatedCluster
370+
}

0 commit comments

Comments
 (0)