Skip to content

Commit 158a5c4

Browse files
committed
fix(gRPC): retrieve status or biz error for non-ServerStreaming
1 parent 7018a57 commit 158a5c4

4 files changed

Lines changed: 289 additions & 3 deletions

File tree

pkg/remote/codec/grpc/grpc.go

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"encoding/binary"
2222
"errors"
2323
"fmt"
24+
"io"
2425

2526
"github.com/cloudwego/fastpb"
2627

@@ -39,7 +40,10 @@ import (
3940

4041
const dataFrameHeaderLen = 5
4142

42-
var ErrInvalidPayload = errors.New("grpc invalid payload")
43+
var (
44+
ErrInvalidPayload = errors.New("grpc invalid payload")
45+
errWrongGRPCImplementation = errors.New("KITEX: grpc client streaming protocol violation: get <nil>, want <EOF>")
46+
)
4347

4448
// gogoproto generate
4549
type marshaler interface {
@@ -204,6 +208,24 @@ func (c *grpcCodec) Decode(ctx context.Context, message remote.Message, in remot
204208
if err != nil {
205209
return err
206210
}
211+
// For ClientStreaming and Unary, server may return an err(e.g. status) as trailer frame.
212+
// We need to receive this trailer frame.
213+
if message.RPCRole() == remote.Client && isNonServerStreaming(message.RPCInfo().Invocation().StreamingMode()) {
214+
// Receive trailer frame
215+
// If err == nil, wrong gRPC protocol implementation.
216+
// If err == io.EOF, it means server returns nil, just ignore io.EOF.
217+
// If err != io.EOF, it means server returns status err or BizStatusErr, or other gRPC transport error came out,
218+
// we need to throw it to users.
219+
_, err = decodeGRPCFrame(ctx, in)
220+
if err == nil {
221+
return errWrongGRPCImplementation
222+
}
223+
if err != io.EOF {
224+
return err
225+
}
226+
// treat io.EOF as nil
227+
err = nil
228+
}
207229
message.SetPayloadLen(len(d))
208230
data := message.Data()
209231
switch message.RPCInfo().Config().PayloadCodec() {
@@ -255,3 +277,11 @@ func (c *grpcCodec) Decode(ctx context.Context, message remote.Message, in remot
255277
func (c *grpcCodec) Name() string {
256278
return "grpc"
257279
}
280+
281+
func isNonServerStreaming(mode serviceinfo.StreamingMode) bool {
282+
if mode == serviceinfo.StreamingClient || mode == serviceinfo.StreamingUnary || mode == serviceinfo.StreamingNone {
283+
return true
284+
}
285+
// BidiStreaming has the ability of ServerStreaming, is also considered as ServerStreaming
286+
return false
287+
}

pkg/remote/codec/grpc/grpc_test.go

Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,253 @@
1+
/*
2+
* Copyright 2025 CloudWeGo Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package grpc
18+
19+
import (
20+
"context"
21+
"io"
22+
"testing"
23+
24+
"github.com/golang/mock/gomock"
25+
26+
mocksremote "github.com/cloudwego/kitex/internal/mocks/remote"
27+
"github.com/cloudwego/kitex/internal/test"
28+
"github.com/cloudwego/kitex/pkg/kerrors"
29+
"github.com/cloudwego/kitex/pkg/remote"
30+
"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/codes"
31+
"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/status"
32+
"github.com/cloudwego/kitex/pkg/rpcinfo"
33+
"github.com/cloudwego/kitex/pkg/rpcinfo/remoteinfo"
34+
"github.com/cloudwego/kitex/pkg/serviceinfo"
35+
)
36+
37+
func Test_grpcCodec_Decode(t *testing.T) {
38+
codec := NewGRPCCodec()
39+
ctrl := gomock.NewController(t)
40+
defer ctrl.Finish()
41+
42+
testcases := []struct {
43+
desc string
44+
role remote.RPCRole
45+
mode serviceinfo.StreamingMode
46+
getByteBufferFunc func() remote.ByteBuffer
47+
expectErr error
48+
}{
49+
{
50+
desc: "client-side ClientStreaming decodes first grpc frame failed",
51+
role: remote.Client,
52+
mode: serviceinfo.StreamingClient,
53+
getByteBufferFunc: func() remote.ByteBuffer {
54+
mockIn := mocksremote.NewMockByteBuffer(ctrl)
55+
mockIn.EXPECT().Next(5).Return(nil, status.Err(codes.Internal, "test")).Times(1)
56+
return mockIn
57+
},
58+
expectErr: status.Err(codes.Internal, "test"),
59+
},
60+
{
61+
desc: "client-side ClientStreaming decodes second grpc frame successfully => wrong gRPC protocol implementation on the server side",
62+
role: remote.Client,
63+
mode: serviceinfo.StreamingClient,
64+
getByteBufferFunc: func() remote.ByteBuffer {
65+
mockIn := mocksremote.NewMockByteBuffer(ctrl)
66+
mockIn.EXPECT().Next(5).Return([]byte{0, 0, 0, 0, 1}, nil).Times(1)
67+
mockIn.EXPECT().Next(1).Return([]byte{1}, nil).Times(1)
68+
mockIn.EXPECT().Next(5).Return([]byte{0, 0, 0, 0, 1}, nil).Times(1)
69+
mockIn.EXPECT().Next(1).Return([]byte{1}, nil).Times(1)
70+
return mockIn
71+
},
72+
expectErr: errWrongGRPCImplementation,
73+
},
74+
{
75+
desc: "client-side ClientStreaming decodes second grpc frame getting io.EOF => normal exit on the server side",
76+
role: remote.Client,
77+
mode: serviceinfo.StreamingClient,
78+
getByteBufferFunc: func() remote.ByteBuffer {
79+
mockIn := mocksremote.NewMockByteBuffer(ctrl)
80+
mockIn.EXPECT().Next(5).Return([]byte{0, 0, 0, 0, 1}, nil).Times(1)
81+
mockIn.EXPECT().Next(1).Return([]byte{1}, nil).Times(1)
82+
mockIn.EXPECT().Next(5).Return(nil, io.EOF).Times(1)
83+
return mockIn
84+
},
85+
expectErr: ErrInvalidPayload,
86+
},
87+
{
88+
desc: "client-side ClientStreaming decodes second grpc frame getting gRPC errors",
89+
role: remote.Client,
90+
mode: serviceinfo.StreamingClient,
91+
getByteBufferFunc: func() remote.ByteBuffer {
92+
mockIn := mocksremote.NewMockByteBuffer(ctrl)
93+
mockIn.EXPECT().Next(5).Return([]byte{0, 0, 0, 0, 1}, nil).Times(1)
94+
mockIn.EXPECT().Next(1).Return([]byte{1}, nil).Times(1)
95+
mockIn.EXPECT().Next(5).Return(nil, status.Err(codes.Internal, "gRPC errors")).Times(1)
96+
return mockIn
97+
},
98+
expectErr: status.Err(codes.Internal, "gRPC errors"),
99+
},
100+
{
101+
desc: "client-side ClientStreaming decodes second grpc frame getting biz error",
102+
role: remote.Client,
103+
mode: serviceinfo.StreamingClient,
104+
getByteBufferFunc: func() remote.ByteBuffer {
105+
mockIn := mocksremote.NewMockByteBuffer(ctrl)
106+
mockIn.EXPECT().Next(5).Return([]byte{0, 0, 0, 0, 1}, nil).Times(1)
107+
mockIn.EXPECT().Next(1).Return([]byte{1}, nil).Times(1)
108+
mockIn.EXPECT().Next(5).Return(nil, kerrors.NewGRPCBizStatusError(10000, "test")).Times(1)
109+
return mockIn
110+
},
111+
expectErr: kerrors.NewGRPCBizStatusError(10000, "test"),
112+
},
113+
{
114+
desc: "client-side Unary decodes second grpc frame getting io.EOF => normal exit on the server side",
115+
role: remote.Client,
116+
mode: serviceinfo.StreamingUnary,
117+
getByteBufferFunc: func() remote.ByteBuffer {
118+
mockIn := mocksremote.NewMockByteBuffer(ctrl)
119+
mockIn.EXPECT().Next(5).Return([]byte{0, 0, 0, 0, 1}, nil).Times(1)
120+
mockIn.EXPECT().Next(1).Return([]byte{1}, nil).Times(1)
121+
mockIn.EXPECT().Next(5).Return(nil, io.EOF).Times(1)
122+
return mockIn
123+
},
124+
expectErr: ErrInvalidPayload,
125+
},
126+
{
127+
desc: "client-side Unary decodes second grpc frame getting biz error",
128+
role: remote.Client,
129+
mode: serviceinfo.StreamingUnary,
130+
getByteBufferFunc: func() remote.ByteBuffer {
131+
mockIn := mocksremote.NewMockByteBuffer(ctrl)
132+
mockIn.EXPECT().Next(5).Return([]byte{0, 0, 0, 0, 1}, nil).Times(1)
133+
mockIn.EXPECT().Next(1).Return([]byte{1}, nil).Times(1)
134+
mockIn.EXPECT().Next(5).Return(nil, kerrors.NewGRPCBizStatusError(10000, "test")).Times(1)
135+
return mockIn
136+
},
137+
expectErr: kerrors.NewGRPCBizStatusError(10000, "test"),
138+
},
139+
{
140+
desc: "client-side None decodes second grpc frame getting io.EOF => normal exit on the server side",
141+
role: remote.Client,
142+
mode: serviceinfo.StreamingUnary,
143+
getByteBufferFunc: func() remote.ByteBuffer {
144+
mockIn := mocksremote.NewMockByteBuffer(ctrl)
145+
mockIn.EXPECT().Next(5).Return([]byte{0, 0, 0, 0, 1}, nil).Times(1)
146+
mockIn.EXPECT().Next(1).Return([]byte{1}, nil).Times(1)
147+
mockIn.EXPECT().Next(5).Return(nil, io.EOF).Times(1)
148+
return mockIn
149+
},
150+
expectErr: ErrInvalidPayload,
151+
},
152+
{
153+
desc: "client-side None decodes second grpc frame getting biz error",
154+
role: remote.Client,
155+
mode: serviceinfo.StreamingNone,
156+
getByteBufferFunc: func() remote.ByteBuffer {
157+
mockIn := mocksremote.NewMockByteBuffer(ctrl)
158+
mockIn.EXPECT().Next(5).Return([]byte{0, 0, 0, 0, 1}, nil).Times(1)
159+
mockIn.EXPECT().Next(1).Return([]byte{1}, nil).Times(1)
160+
mockIn.EXPECT().Next(5).Return(nil, kerrors.NewGRPCBizStatusError(10000, "test")).Times(1)
161+
return mockIn
162+
},
163+
expectErr: kerrors.NewGRPCBizStatusError(10000, "test"),
164+
},
165+
{
166+
desc: "client-side ServerStreaming decodes",
167+
role: remote.Client,
168+
mode: serviceinfo.StreamingServer,
169+
getByteBufferFunc: func() remote.ByteBuffer {
170+
mockIn := mocksremote.NewMockByteBuffer(ctrl)
171+
mockIn.EXPECT().Next(5).Return([]byte{0, 0, 0, 0, 1}, nil).Times(1)
172+
mockIn.EXPECT().Next(1).Return([]byte{1}, nil).Times(1)
173+
return mockIn
174+
},
175+
expectErr: ErrInvalidPayload,
176+
},
177+
{
178+
desc: "client-side BidiStreaming decodes",
179+
role: remote.Client,
180+
mode: serviceinfo.StreamingBidirectional,
181+
getByteBufferFunc: func() remote.ByteBuffer {
182+
mockIn := mocksremote.NewMockByteBuffer(ctrl)
183+
mockIn.EXPECT().Next(5).Return([]byte{0, 0, 0, 0, 1}, nil).Times(1)
184+
mockIn.EXPECT().Next(1).Return([]byte{1}, nil).Times(1)
185+
return mockIn
186+
},
187+
expectErr: ErrInvalidPayload,
188+
},
189+
{
190+
desc: "server-side decodes",
191+
role: remote.Server,
192+
getByteBufferFunc: func() remote.ByteBuffer {
193+
mockIn := mocksremote.NewMockByteBuffer(ctrl)
194+
mockIn.EXPECT().Next(5).Return([]byte{0, 0, 0, 0, 1}, nil).Times(1)
195+
mockIn.EXPECT().Next(1).Return([]byte{1}, nil).Times(1)
196+
return mockIn
197+
},
198+
expectErr: ErrInvalidPayload,
199+
},
200+
}
201+
mockServiceName := "grpcService"
202+
mockMethod := "InvokeClientStreaming"
203+
for _, tc := range testcases {
204+
t.Run(tc.desc, func(t *testing.T) {
205+
inv := rpcinfo.NewInvocation(mockServiceName, mockMethod)
206+
inv.SetStreamingMode(tc.mode)
207+
cfg := rpcinfo.NewRPCConfig()
208+
// avoid unmarshal
209+
rpcinfo.AsMutableRPCConfig(cfg).SetPayloadCodec(serviceinfo.PayloadCodec(-1))
210+
ri := rpcinfo.NewRPCInfo(
211+
rpcinfo.EmptyEndpointInfo(),
212+
remoteinfo.NewRemoteInfo(&rpcinfo.EndpointBasicInfo{ServiceName: mockServiceName}, mockMethod).ImmutableView(),
213+
inv, cfg, rpcinfo.NewRPCStats())
214+
ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri)
215+
mockMsg := remote.NewMessage(nil, ri, remote.Stream, tc.role)
216+
mockIn := tc.getByteBufferFunc()
217+
err := codec.Decode(ctx, mockMsg, mockIn)
218+
test.DeepEqual(t, err, tc.expectErr)
219+
})
220+
}
221+
}
222+
223+
func Test_isNonServerStreaming(t *testing.T) {
224+
testcases := []struct {
225+
mode serviceinfo.StreamingMode
226+
expectRes bool
227+
}{
228+
{
229+
mode: serviceinfo.StreamingNone,
230+
expectRes: true,
231+
},
232+
{
233+
mode: serviceinfo.StreamingUnary,
234+
expectRes: true,
235+
},
236+
{
237+
mode: serviceinfo.StreamingClient,
238+
expectRes: true,
239+
},
240+
{
241+
mode: serviceinfo.StreamingServer,
242+
expectRes: false,
243+
},
244+
{
245+
mode: serviceinfo.StreamingBidirectional,
246+
expectRes: false,
247+
},
248+
}
249+
250+
for _, tc := range testcases {
251+
test.Assert(t, isNonServerStreaming(tc.mode) == tc.expectRes)
252+
}
253+
}

pkg/remote/trans/nphttp2/server_handler_test.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,9 @@ func TestServerHandler(t *testing.T) {
8888
msg.RPCInfoFunc = func() rpcinfo.RPCInfo {
8989
return ri
9090
}
91+
msg.RPCRoleFunc = func() remote.RPCRole {
92+
return remote.Server
93+
}
9194
npConn := newMockNpConn(mockAddr0)
9295
npConn.mockSettingFrame()
9396
tr, err := newMockServerTransport(npConn)

pkg/remote/trans/nphttp2/stream.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ func (s *serverStream) SetTrailer(tl streaming.Trailer) error {
121121
func (s *serverStream) RecvMsg(ctx context.Context, m interface{}) error {
122122
ri := s.rpcInfo
123123

124-
msg := remote.NewMessage(m, ri, remote.Stream, remote.Client)
124+
msg := remote.NewMessage(m, ri, remote.Stream, remote.Server)
125125
defer msg.Recycle()
126126

127127
_, err := s.handler.Read(s.ctx, s.conn, msg)
@@ -133,7 +133,7 @@ func (s *serverStream) RecvMsg(ctx context.Context, m interface{}) error {
133133
func (s *serverStream) SendMsg(ctx context.Context, m interface{}) error {
134134
ri := s.rpcInfo
135135

136-
msg := remote.NewMessage(m, ri, remote.Stream, remote.Client)
136+
msg := remote.NewMessage(m, ri, remote.Stream, remote.Server)
137137
defer msg.Recycle()
138138

139139
_, err := s.handler.Write(s.ctx, s.conn, msg)

0 commit comments

Comments
 (0)