-
-
Notifications
You must be signed in to change notification settings - Fork 584
Expand file tree
/
Copy pathstreaming_errors_test.go
More file actions
191 lines (166 loc) · 6.08 KB
/
streaming_errors_test.go
File metadata and controls
191 lines (166 loc) · 6.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
package codegen
import (
"strings"
"testing"
. "goa.design/goa/v3/dsl"
"goa.design/goa/v3/grpc/codegen/testdata"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestStreamingWithErrors checks that the client-side recv implementations
// generated for streaming gRPC methods decode stream errors and dispatch on
// the error message types declared in the service DSL.
//
// Each case runs its DSL, generates the client files, concatenates every
// "client-stream-recv" section of the first client file, and hands the
// resulting source text to a case-specific checker.
func TestStreamingWithErrors(t *testing.T) {
	tests := []struct {
		name     string
		dsl      func()
		testFunc func(t *testing.T, code string)
	}{
		{
			name: "server streaming with custom errors",
			dsl:  testdata.ServerStreamingWithCustomErrorsDSL,
			testFunc: func(t *testing.T, code string) {
				expectations := []struct{ snippet, msg string }{
					// The raw gRPC error must be decoded first.
					{"goagrpc.DecodeError(err)", "should decode errors from stream"},
					// Each declared custom error gets its own case.
					{"case *streaming_error_servicepb.ServerStreamCustomErrorError:", "should handle custom error type"},
					{"case *streaming_error_servicepb.ServerStreamValidationErrorError:", "should handle validation error type"},
					// Generic goa errors fall through to the shared response type.
					{"case *goapb.ErrorResponse:", "should handle generic goa errors"},
					// The service-level errors are built from the decoded message.
					{"NewServerStreamCustomErrorError(message", "should construct custom error"},
					{"NewServerStreamValidationErrorError(message", "should construct validation error"},
				}
				for _, e := range expectations {
					assert.Contains(t, code, e.snippet, e.msg)
				}
			},
		},
		{
			name: "bidirectional streaming with errors",
			dsl:  testdata.BidirectionalStreamingRPCWithErrorsDSL,
			testFunc: func(t *testing.T, code string) {
				// Even with only simple errors, the bidirectional stream must
				// decode errors and handle the generic goa error response.
				expectations := []struct{ snippet, msg string }{
					{"goagrpc.DecodeError(err)", "should decode errors from bidirectional stream"},
					{"case *goapb.ErrorResponse:", "should handle generic errors in bidirectional streaming"},
				}
				for _, e := range expectations {
					assert.Contains(t, code, e.snippet, e.msg)
				}
			},
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			root := RunGRPCDSL(t, tc.dsl)
			services := CreateGRPCServices(root)
			files := ClientFiles("", services)
			require.Greater(t, len(files), 0)

			// Render every recv implementation into one string so the
			// checker can search across all of them at once.
			sections := files[0].Section("client-stream-recv")
			require.Greater(t, len(sections), 0)
			var sb strings.Builder
			for _, s := range sections {
				require.NoError(t, s.Write(&sb))
			}

			tc.testFunc(t, sb.String())
		})
	}
}
// TestStreamingErrorsWithValidation verifies that custom errors with
// validation rules are validated in streaming recv methods.
//
// The test first sanity-checks the DSL (one service, one method, at least one
// error), then renders the "client-stream-recv" sections of the first client
// file. It requires those sections to exist and to be non-empty. Without that
// requirement, a missing or renamed section would yield empty code and the
// conditional assertion below would pass without checking anything.
func TestStreamingErrorsWithValidation(t *testing.T) {
	root := RunGRPCDSL(t, testdata.ServerStreamingWithCustomErrorsDSL)
	services := CreateGRPCServices(root)

	// Verify the DSL has errors defined.
	require.Len(t, root.Services, 1)
	svc := root.Services[0]
	require.Len(t, svc.Methods, 1)
	method := svc.Methods[0]
	require.Greater(t, len(method.Errors), 0, "method should have errors defined")

	// Generate client code.
	clientfs := ClientFiles("", services)
	require.Greater(t, len(clientfs), 0)

	// Render the recv implementations; fail loudly if none were generated.
	recvSections := clientfs[0].Section("client-stream-recv")
	require.Greater(t, len(recvSections), 0, "should generate client stream recv methods")
	var code strings.Builder
	for _, section := range recvSections {
		require.NoError(t, section.Write(&code))
	}
	recvCode := code.String()
	require.NotEmpty(t, recvCode, "client stream recv methods should not be empty")

	// Whether a Validate function is referenced depends on the error type's
	// validation rules (defined in testdata, not visible here). When it is
	// referenced, it must be invoked on the decoded message with its error
	// checked.
	if strings.Contains(recvCode, "ValidateServerStreamCustomErrorError") {
		assert.Contains(t, recvCode, "if err := ValidateServerStreamCustomErrorError(message); err != nil {",
			"should validate custom error before returning")
	}
}
// TestStreamingErrorComparison checks that a unary method and a server
// streaming method declaring the same custom error both decode gRPC errors
// and switch on generated error message types, so error handling stays
// consistent between the two method kinds.
//
// Unary error handling is read from the "client-endpoint-init" sections and
// streaming error handling from the "client-stream-recv" sections of the
// first client file. The test is skipped if either set of sections renders
// to empty text.
func TestStreamingErrorComparison(t *testing.T) {
	dsl := func() {
		customError := Type("CustomError", func() {
			ErrorName("name", String, "error name")
			Attribute("message", String, "error message")
			Required("name", "message")
		})

		// Both methods declare the same error mapped to the same gRPC code.
		withCustomError := func() {
			Error("custom_error", customError)
			GRPC(func() {
				Response("custom_error", CodeInvalidArgument)
			})
		}

		Service("MixedService", func() {
			Method("UnaryMethod", func() {
				Payload(String)
				Result(String)
				withCustomError()
			})
			Method("StreamingMethod", func() {
				Payload(String)
				StreamingResult(String)
				withCustomError()
			})
		})
	}

	root := RunGRPCDSL(t, dsl)
	clientfs := ClientFiles("", CreateGRPCServices(root))
	require.Greater(t, len(clientfs), 0, "should have client files")

	// sectionCode renders every section with the given name from the first
	// client file and returns the concatenated text ("" if there are none).
	sectionCode := func(name string) string {
		var sb strings.Builder
		for _, s := range clientfs[0].Section(name) {
			require.NoError(t, s.Write(&sb))
		}
		return sb.String()
	}
	unaryCode := sectionCode("client-endpoint-init")
	streamCode := sectionCode("client-stream-recv")

	if unaryCode == "" || streamCode == "" {
		t.Skip("Cannot compare unary and streaming - sections not found in generated code")
	}

	// Both method kinds must decode errors and handle the custom error types.
	checks := []struct{ code, snippet, msg string }{
		{unaryCode, "goagrpc.DecodeError(err)", "unary methods should decode errors"},
		{streamCode, "goagrpc.DecodeError(err)", "streaming methods should decode errors"},
		{unaryCode, "case *mixed_servicepb.", "unary should handle custom error types"},
		{streamCode, "case *mixed_servicepb.", "streaming should handle custom error types"},
	}
	for _, c := range checks {
		assert.Contains(t, c.code, c.snippet, c.msg)
	}
}