1+ package codegen
2+
3+ import (
4+ "strings"
5+ "testing"
6+
7+ . "goa.design/goa/v3/dsl"
8+ "goa.design/goa/v3/grpc/codegen/testdata"
9+ "github.com/stretchr/testify/assert"
10+ "github.com/stretchr/testify/require"
11+ )
12+
13+ // TestStreamingWithErrors tests that streaming endpoints properly handle
14+ // custom errors defined in the service DSL.
15+ func TestStreamingWithErrors (t * testing.T ) {
16+ cases := []struct {
17+ name string
18+ dsl func ()
19+ testFunc func (t * testing.T , code string )
20+ }{
21+ {
22+ name : "server streaming with custom errors" ,
23+ dsl : testdata .ServerStreamingWithCustomErrorsDSL ,
24+ testFunc : func (t * testing.T , code string ) {
25+ // Verify error decoding is present
26+ assert .Contains (t , code , "goagrpc.DecodeError(err)" ,
27+ "should decode errors from stream" )
28+
29+ // Verify custom error types are handled
30+ assert .Contains (t , code , "case *streaming_error_servicepb.ServerStreamCustomErrorError:" ,
31+ "should handle custom error type" )
32+ assert .Contains (t , code , "case *streaming_error_servicepb.ServerStreamValidationErrorError:" ,
33+ "should handle validation error type" )
34+
35+ // Verify generic errors are handled
36+ assert .Contains (t , code , "case *goapb.ErrorResponse:" ,
37+ "should handle generic goa errors" )
38+
39+ // Verify proper error construction
40+ assert .Contains (t , code , "NewServerStreamCustomErrorError(message" ,
41+ "should construct custom error" )
42+ assert .Contains (t , code , "NewServerStreamValidationErrorError(message" ,
43+ "should construct validation error" )
44+ },
45+ },
46+ {
47+ name : "bidirectional streaming with errors" ,
48+ dsl : testdata .BidirectionalStreamingRPCWithErrorsDSL ,
49+ testFunc : func (t * testing.T , code string ) {
50+ // Bidirectional streaming with simple errors should still decode
51+ assert .Contains (t , code , "goagrpc.DecodeError(err)" ,
52+ "should decode errors from bidirectional stream" )
53+ assert .Contains (t , code , "case *goapb.ErrorResponse:" ,
54+ "should handle generic errors in bidirectional streaming" )
55+ },
56+ },
57+ }
58+
59+ for _ , c := range cases {
60+ t .Run (c .name , func (t * testing.T ) {
61+ root := RunGRPCDSL (t , c .dsl )
62+ services := CreateGRPCServices (root )
63+ clientfs := ClientFiles ("" , services )
64+ require .Greater (t , len (clientfs ), 0 )
65+
66+ // Get recv method implementations
67+ recvSections := clientfs [0 ].Section ("client-stream-recv" )
68+ require .Greater (t , len (recvSections ), 0 )
69+
70+ // Build complete recv method code
71+ var codeBuilder strings.Builder
72+ for _ , section := range recvSections {
73+ require .NoError (t , section .Write (& codeBuilder ))
74+ }
75+ code := codeBuilder .String ()
76+
77+ // Run test-specific assertions
78+ c .testFunc (t , code )
79+ })
80+ }
81+ }
82+
83+ // TestStreamingErrorsWithValidation verifies that custom errors with
84+ // validation rules are properly validated in streaming recv methods.
85+ func TestStreamingErrorsWithValidation (t * testing.T ) {
86+ root := RunGRPCDSL (t , testdata .ServerStreamingWithCustomErrorsDSL )
87+ services := CreateGRPCServices (root )
88+
89+ // Verify the DSL has errors with validation
90+ require .Len (t , root .Services , 1 )
91+ svc := root .Services [0 ]
92+ require .Len (t , svc .Methods , 1 )
93+ method := svc .Methods [0 ]
94+ require .Greater (t , len (method .Errors ), 0 , "method should have errors defined" )
95+
96+ // Generate client code
97+ clientfs := ClientFiles ("" , services )
98+ require .Greater (t , len (clientfs ), 0 )
99+
100+ // Check recv implementations
101+ recvSections := clientfs [0 ].Section ("client-stream-recv" )
102+ var code strings.Builder
103+ for _ , section := range recvSections {
104+ require .NoError (t , section .Write (& code ))
105+ }
106+ recvCode := code .String ()
107+
108+ // For errors with validation, verify validation is called
109+ if strings .Contains (recvCode , "ValidateServerStreamCustomErrorError" ) {
110+ assert .Contains (t , recvCode , "if err := ValidateServerStreamCustomErrorError(message); err != nil {" ,
111+ "should validate custom error before returning" )
112+ }
113+ }
114+
115+ // TestStreamingErrorComparison compares error handling between unary and
116+ // streaming methods to ensure consistency.
117+ func TestStreamingErrorComparison (t * testing.T ) {
118+ // DSL with both unary and streaming methods with errors
119+ dsl := func () {
120+ var CustomError = Type ("CustomError" , func () {
121+ ErrorName ("name" , String , "error name" )
122+ Attribute ("message" , String , "error message" )
123+ Required ("name" , "message" )
124+ })
125+
126+ Service ("MixedService" , func () {
127+ // Unary method with custom error
128+ Method ("UnaryMethod" , func () {
129+ Payload (String )
130+ Result (String )
131+ Error ("custom_error" , CustomError )
132+ GRPC (func () {
133+ Response ("custom_error" , CodeInvalidArgument )
134+ })
135+ })
136+
137+ // Streaming method with same error
138+ Method ("StreamingMethod" , func () {
139+ Payload (String )
140+ StreamingResult (String )
141+ Error ("custom_error" , CustomError )
142+ GRPC (func () {
143+ Response ("custom_error" , CodeInvalidArgument )
144+ })
145+ })
146+ })
147+ }
148+
149+ root := RunGRPCDSL (t , dsl )
150+ services := CreateGRPCServices (root )
151+ clientfs := ClientFiles ("" , services )
152+ require .Greater (t , len (clientfs ), 0 , "should have client files" )
153+
154+ // Find unary and streaming code in different sections
155+ var unaryCode , streamCode string
156+
157+ // For unary, look in client-endpoint-init
158+ if sections := clientfs [0 ].Section ("client-endpoint-init" ); len (sections ) > 0 {
159+ var code strings.Builder
160+ for _ , section := range sections {
161+ require .NoError (t , section .Write (& code ))
162+ }
163+ unaryCode = code .String ()
164+ }
165+
166+ // For streaming, look in client-stream-recv
167+ if sections := clientfs [0 ].Section ("client-stream-recv" ); len (sections ) > 0 {
168+ var code strings.Builder
169+ for _ , section := range sections {
170+ require .NoError (t , section .Write (& code ))
171+ }
172+ streamCode = code .String ()
173+ }
174+
175+ // If no sections found, skip test with explanation
176+ if unaryCode == "" || streamCode == "" {
177+ t .Skip ("Cannot compare unary and streaming - sections not found in generated code" )
178+ }
179+
180+ // Both should decode errors
181+ assert .Contains (t , unaryCode , "goagrpc.DecodeError(err)" ,
182+ "unary methods should decode errors" )
183+ assert .Contains (t , streamCode , "goagrpc.DecodeError(err)" ,
184+ "streaming methods should decode errors" )
185+
186+ // Both should handle the custom error type
187+ assert .Contains (t , unaryCode , "case *mixed_servicepb." ,
188+ "unary should handle custom error types" )
189+ assert .Contains (t , streamCode , "case *mixed_servicepb." ,
190+ "streaming should handle custom error types" )
191+ }
0 commit comments