@@ -21,6 +21,7 @@ import (
2121 "fmt"
2222 "io"
2323 "runtime/debug"
24+ "sync"
2425 "sync/atomic"
2526 "time"
2627
@@ -35,6 +36,7 @@ import (
3536 "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/codes"
3637 "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/metadata"
3738 "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/status"
39+ "github.com/cloudwego/kitex/pkg/remote/trans/ttstream"
3840 "github.com/cloudwego/kitex/pkg/rpcinfo"
3941 "github.com/cloudwego/kitex/pkg/serviceinfo"
4042 "github.com/cloudwego/kitex/pkg/streaming"
@@ -216,6 +218,9 @@ type stream struct {
216218 streamingMode serviceinfo.StreamingMode
217219 finished uint32
218220 isGRPC bool
221+
222+ finishedErrOnce sync.Once
223+ finishedErr atomic.Value
219224}
220225
221226var (
@@ -254,6 +259,9 @@ func newStream(ctx context.Context, s streaming.ClientStream, scm *remotecli.Str
254259
255260// Header returns the header data sent by the server if any.
256261func (s * stream ) Header () (hd streaming.Header , err error ) {
262+ if atomic .LoadUint32 (& s .finished ) == 1 {
263+ return nil , s .finishedErr .Load ().(error )
264+ }
257265 if hd , err = s .ClientStream .Header (); err != nil {
258266 s .DoFinish (err )
259267 }
@@ -263,6 +271,9 @@ func (s *stream) Header() (hd streaming.Header, err error) {
263271// RecvMsg receives a message from the server.
264272// If an error is returned, stream.DoFinish() will be called to record the end of stream
265273func (s * stream ) RecvMsg (ctx context.Context , m interface {}) (err error ) {
274+ if atomic .LoadUint32 (& s .finished ) == 1 {
275+ return s .finishedErr .Load ().(error )
276+ }
266277 if ! s .recv .EqualsTo (recvEndpoint ) {
267278 // If the values are not equal, it indicates the presence of custom middleware.
268279 // To prevent errors caused by middleware code that relies on rpcinfo when users
@@ -293,19 +304,27 @@ func (s *stream) handleStreamRecvEvent(err error) {
293304}
294305
295306func (s * stream ) recvWithTimeout (ctx context.Context , m interface {}) error {
296- if ! s . isGRPC || s .recvTmCfg .Timeout <= 0 {
307+ if s .recvTmCfg .Timeout <= 0 {
297308 return s .recv (ctx , s .ClientStream , m )
298309 }
310+ buildTmErr := buildTTStreamRecvTimeoutErr
311+ buildPanicErr := buildTTStreamRecvPanicErr
312+ if s .isGRPC {
313+ buildTmErr = buildGRPCRecvTimeoutErr
314+ buildPanicErr = buildGRPCRecvPanicErr
315+ }
299316 return callWithTimeout (s .recvTmCfg ,
300317 func () error {
301318 return s .recv (ctx , s .ClientStream , m )
302319 },
303320 s .cancel ,
321+ buildTmErr ,
322+ buildPanicErr ,
304323 )
305324}
306325
307326func (s * stream ) cancel (err error ) {
308- // for now, only gRPC ClientStream implements CancelableClientStream interface
327+ // ClientStream of gRPC and ttstream both implements CancelableClientStream interface
309328 if c , ok := s .ClientStream .(internal_stream.CancelableClientStream ); ok {
310329 c .CancelWithErr (err )
311330 }
@@ -314,6 +333,9 @@ func (s *stream) cancel(err error) {
314333// SendMsg sends a message to the server.
315334// If an error is returned, stream.DoFinish() will be called to record the end of stream
316335func (s * stream ) SendMsg (ctx context.Context , m interface {}) (err error ) {
336+ if atomic .LoadUint32 (& s .finished ) == 1 {
337+ return s .finishedErr .Load ().(error )
338+ }
317339 if ! s .send .EqualsTo (sendEndpoint ) {
318340 // same with RecvMsg
319341 ri := rpcinfo .GetRPCInfo (ctx )
@@ -338,6 +360,17 @@ func (s *stream) handleStreamSendEvent(err error) {
338360// DoFinish implements the streaming.WithDoFinish interface, and it records the end of stream
339361// It will release the connection.
340362func (s * stream ) DoFinish (err error ) {
363+ s .finishedErrOnce .Do (func () {
364+ // store the finished err so that subsequent Recv/Send/Header calls would fail fast
365+ // and return the same finished err.
366+ // When err is nil (e.g. client streaming success), use io.EOF as the sentinel
367+ // since the stream is done and no more messages can be received.
368+ if err != nil {
369+ s .finishedErr .Store (err )
370+ } else {
371+ s .finishedErr .Store (io .EOF )
372+ }
373+ })
341374 if atomic .SwapUint32 (& s .finished , 1 ) == 1 {
342375 // already called
343376 return
@@ -413,6 +446,8 @@ func (s *grpcStream) recvWithTimeout(m interface{}) error {
413446 return s .recvEndpoint (s .Stream , m )
414447 },
415448 s .st .cancel ,
449+ buildGRPCRecvTimeoutErr ,
450+ buildGRPCRecvPanicErr ,
416451 )
417452}
418453
@@ -429,15 +464,19 @@ func (s *grpcStream) DoFinish(err error) {
429464 s .st .DoFinish (err )
430465}
431466
432- func callWithTimeout (tmCfg streaming.TimeoutConfig , call func () error , cancel func (error )) error {
467+ func callWithTimeout (tmCfg streaming.TimeoutConfig ,
468+ call func () error , cancel func (error ),
469+ buildTmErr func (streaming.TimeoutConfig ) error ,
470+ buildPanicErr func (interface {}, []byte ) error ,
471+ ) error {
433472 timer := time .NewTimer (tmCfg .Timeout )
434473 defer timer .Stop ()
435474 finishChan := make (chan error , 1 )
436475 gopool .Go (func () {
437476 var callErr error
438477 defer func () {
439478 if r := recover (); r != nil {
440- callErr = status . Errorf ( codes . Internal , "stream Recv panic, panic=%v, stack=%s" , r , debug .Stack ())
479+ callErr = buildPanicErr ( r , debug .Stack ())
441480 cancel (callErr )
442481 }
443482 finishChan <- callErr
@@ -446,7 +485,8 @@ func callWithTimeout(tmCfg streaming.TimeoutConfig, call func() error, cancel fu
446485 })
447486 select {
448487 case <- timer .C :
449- err := status .Errorf (codes .RecvDeadlineExceeded , recvTimeoutErrTpl , tmCfg )
488+ err := buildTmErr (tmCfg )
489+ // if DisableCancelRemote == true, users are responsible for ensuring that the stream is not leaked.
450490 if ! tmCfg .DisableCancelRemote {
451491 // finish the stream lifecycle so that the goroutine could exit
452492 cancel (err )
@@ -469,6 +509,22 @@ func isRPCError(err error) bool {
469509 return ! isBizStatusError
470510}
471511
512+ func buildGRPCRecvTimeoutErr (tmCfg streaming.TimeoutConfig ) error {
513+ return status .Errorf (codes .RecvDeadlineExceeded , recvTimeoutErrTpl , tmCfg )
514+ }
515+
516+ func buildGRPCRecvPanicErr (r interface {}, stack []byte ) error {
517+ return status .Errorf (codes .Internal , "stream Recv panic, panic=%v, stack=%s" , r , stack )
518+ }
519+
520+ func buildTTStreamRecvTimeoutErr (tmCfg streaming.TimeoutConfig ) error {
521+ return ttstream .NewStreamRecvTimeoutException (tmCfg , true )
522+ }
523+
524+ func buildTTStreamRecvPanicErr (r interface {}, stack []byte ) error {
525+ return ttstream .NewStreamInternalException (fmt .Sprintf ("stream Recv panic, panic=%v, stack=%s" , r , stack ), true )
526+ }
527+
472528var (
473529 recvEndpoint cep.StreamRecvEndpoint = func (ctx context.Context , stream streaming.ClientStream , m interface {}) error {
474530 return stream .RecvMsg (ctx , m )
0 commit comments