Skip to content

Commit 7018a57

Browse files
authored
fix(ttstream): use consistent context of stream in ttstream.RecvMsg and fix nil message error of binary generic (cloudwego#1866)
1 parent 4a7d8e3 commit 7018a57

4 files changed

Lines changed: 85 additions & 3 deletions

File tree

pkg/generic/thrift/raw.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,14 @@ var _ MessageWriter = (*RawWriter)(nil)
4949

5050
// Write returns the copy of data
5151
func (m *RawWriter) Write(ctx context.Context, out bufiox.Writer, msg interface{}, method string, isClient bool, requestBase *base.Base) error {
52+
if msg == nil {
53+
bw := thrift.NewBufferWriter(out)
54+
defer bw.Recycle()
55+
if err := bw.WriteFieldStop(); err != nil {
56+
return err
57+
}
58+
return nil
59+
}
5260
buf, ok := msg.([]byte)
5361
if !ok {
5462
return fmt.Errorf("thrift binary generic msg is not []byte, method=%v", method)

pkg/generic/thrift/raw_test.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ import (
2323

2424
"github.com/cloudwego/gopkg/bufiox"
2525
"github.com/cloudwego/gopkg/protocol/thrift"
26+
27+
"github.com/cloudwego/kitex/internal/test"
2628
)
2729

2830
func TestRawReader_Read(t *testing.T) {
@@ -46,3 +48,24 @@ func TestRawReader_Read(t *testing.T) {
4648
t.Fatalf("expect %v, got %v", nb[:off], data)
4749
}
4850
}
51+
52+
func TestRawWriter_Write(t *testing.T) {
53+
w := NewRawWriter()
54+
var buf []byte
55+
out := bufiox.NewBytesWriter(&buf)
56+
// nil message
57+
err := w.Write(context.Background(), out, nil, "method", true, nil)
58+
test.Assert(t, err == nil)
59+
err = out.Flush()
60+
test.Assert(t, err == nil)
61+
test.Assert(t, len(buf) == 1) // field stop
62+
test.Assert(t, buf[0] == byte(thrift.STOP))
63+
64+
// normal message
65+
buf = buf[:0]
66+
err = w.Write(context.Background(), out, []byte("hello world"), "method", true, nil)
67+
test.Assert(t, err == nil)
68+
err = out.Flush()
69+
test.Assert(t, err == nil)
70+
test.Assert(t, len(buf) == len("hello world"))
71+
}

pkg/remote/trans/ttstream/stream.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,16 +125,17 @@ func (s *stream) SendMsg(ctx context.Context, msg any) (err error) {
125125
}
126126

127127
func (s *stream) RecvMsg(ctx context.Context, data any) error {
128+
nctx := s.ctx
128129
if s.recvTimeout > 0 {
129130
var cancel context.CancelFunc
130-
ctx, cancel = context.WithTimeout(ctx, s.recvTimeout)
131+
nctx, cancel = context.WithTimeout(nctx, s.recvTimeout)
131132
defer cancel()
132133
}
133-
payload, err := s.reader.output(ctx)
134+
payload, err := s.reader.output(nctx)
134135
if err != nil {
135136
return err
136137
}
137-
err = DecodePayload(ctx, payload, data)
138+
err = DecodePayload(nctx, payload, data)
138139
// payload will not be access after decode
139140
mcache.Free(payload)
140141

pkg/remote/trans/ttstream/stream_test.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@ package ttstream
2020

2121
import (
2222
"context"
23+
"strings"
2324
"testing"
25+
"time"
2426

2527
"github.com/cloudwego/kitex/internal/test"
2628
)
@@ -54,3 +56,51 @@ func TestGenericStreaming(t *testing.T) {
5456
// test.Assert(t, res.A == req.A)
5557
// test.Assert(t, res.B == req.B)
5658
}
59+
60+
// TestStreamRecvTimeout tests that RecvMsg correctly handles timeout scenarios
61+
func TestStreamRecvTimeout(t *testing.T) {
62+
_, ss, err := newTestStreamPipe(testServiceInfo, "Bidi")
63+
test.Assert(t, err == nil, err)
64+
65+
// Set a very short timeout for testing
66+
ss.setRecvTimeout(time.Millisecond * 10)
67+
68+
// Create a context that won't expire
69+
ctx := context.Background()
70+
71+
// Try to receive message - should timeout quickly
72+
res := new(testResponse)
73+
err = ss.RecvMsg(ctx, res)
74+
test.Assert(t, err != nil, "RecvMsg should timeout")
75+
test.Assert(t, strings.Contains(err.Error(), "timeout") ||
76+
strings.Contains(err.Error(), "deadline exceeded"),
77+
"Error should be timeout related")
78+
}
79+
80+
// TestStreamRecvWithCancellation tests that RecvMsg respects context cancellation
81+
func TestStreamRecvWithCancellation(t *testing.T) {
82+
cs, ss, err := newTestStreamPipe(testServiceInfo, "Bidi")
83+
test.Assert(t, err == nil, err)
84+
85+
// Create a cancellable context with short timeout
86+
cancelCtx, cancel := context.WithCancel(context.Background())
87+
88+
cancel()
89+
90+
// Send message
91+
req := new(testRequest)
92+
req.A = 789
93+
req.B = "normal_test"
94+
95+
err = cs.SendMsg(cancelCtx, req)
96+
test.Assert(t, err == nil, err)
97+
98+
// Receive message
99+
res := new(testResponse)
100+
err = ss.RecvMsg(cancelCtx, res)
101+
test.Assert(t, err == nil, err)
102+
103+
// Verify content
104+
test.Assert(t, res.A == req.A)
105+
test.Assert(t, res.B == req.B)
106+
}

0 commit comments

Comments
 (0)