@@ -35,10 +35,12 @@ import (
3535 "github.com/cloudwego/kitex/pkg/event"
3636 "github.com/cloudwego/kitex/pkg/kerrors"
3737 "github.com/cloudwego/kitex/pkg/proxy"
38+ "github.com/cloudwego/kitex/pkg/remote"
3839 "github.com/cloudwego/kitex/pkg/remote/codec/protobuf"
3940 "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/status"
4041 "github.com/cloudwego/kitex/pkg/rpcinfo"
4142 "github.com/cloudwego/kitex/pkg/rpcinfo/remoteinfo"
43+ "github.com/cloudwego/kitex/transport"
4244)
4345
4446var (
@@ -84,7 +86,7 @@ func TestResolverMW(t *testing.T) {
8486
8587 var invoked bool
8688 cli := newMockClient (t , ctrl ).(* kcFinalizerClient )
87- mw := newResolveMWBuilder (cli .lbf )(ctx )
89+ mw := newResolveMWBuilder (cli .lbf , nil )(ctx )
8890 ep := func (ctx context.Context , request , response interface {}) error {
8991 invoked = true
9092 return nil
@@ -114,14 +116,14 @@ func TestResolverMWOutOfInstance(t *testing.T) {
114116 }
115117 var invoked bool
116118 cli := newMockClient (t , ctrl , WithResolver (resolver )).(* kcFinalizerClient )
117- mw := newResolveMWBuilder (cli .lbf )(ctx )
119+ mw := newResolveMWBuilder (cli .lbf , nil )(ctx )
118120 ep := func (ctx context.Context , request , response interface {}) error {
119121 invoked = true
120122 return nil
121123 }
122124
123125 to := remoteinfo .NewRemoteInfo (& rpcinfo.EndpointBasicInfo {}, "" )
124- ri := rpcinfo .NewRPCInfo (nil , to , rpcinfo .NewInvocation ("" , "" ), nil , rpcinfo .NewRPCStats ())
126+ ri := rpcinfo .NewRPCInfo (nil , to , rpcinfo .NewInvocation ("" , "" ), rpcinfo . NewRPCConfig () , rpcinfo .NewRPCStats ())
125127
126128 ctx := rpcinfo .NewCtxWithRPCInfo (context .Background (), ri )
127129 req := new (MockTStruct )
@@ -222,7 +224,7 @@ func BenchmarkResolverMW(b *testing.B) {
222224 defer ctrl .Finish ()
223225
224226 cli := newMockClient (b , ctrl ).(* kcFinalizerClient )
225- mw := newResolveMWBuilder (cli .lbf )(ctx )
227+ mw := newResolveMWBuilder (cli .lbf , nil )(ctx )
226228 ep := func (ctx context.Context , request , response interface {}) error { return nil }
227229 ri := rpcinfo .NewRPCInfo (nil , nil , rpcinfo .NewInvocation ("" , "" ), nil , rpcinfo .NewRPCStats ())
228230
@@ -241,7 +243,7 @@ func BenchmarkResolverMWParallel(b *testing.B) {
241243 defer ctrl .Finish ()
242244
243245 cli := newMockClient (b , ctrl ).(* kcFinalizerClient )
244- mw := newResolveMWBuilder (cli .lbf )(ctx )
246+ mw := newResolveMWBuilder (cli .lbf , nil )(ctx )
245247 ep := func (ctx context.Context , request , response interface {}) error { return nil }
246248 ri := rpcinfo .NewRPCInfo (nil , nil , rpcinfo .NewInvocation ("" , "" ), nil , rpcinfo .NewRPCStats ())
247249
@@ -279,3 +281,170 @@ func TestDiscoveryEventHandler(t *testing.T) {
279281 added := extra ["Added" ].([]* instInfo )
280282 test .Assert (t , len (added ) == 1 )
281283}
284+
285+ // mockConnStatistics implements remote.ConnStatistics for testing
286+ type mockConnStatistics struct {
287+ activeStreams map [string ]int
288+ }
289+
290+ func (m * mockConnStatistics ) ActiveStreams (addr string ) int {
291+ if m .activeStreams == nil {
292+ return 0
293+ }
294+ return m .activeStreams [addr ]
295+ }
296+
297+ // TestResolverMW_WithConnStatistics_StreamingMode tests that ConnStatistics is passed to context
298+ // when in gRPC streaming mode
299+ func TestResolverMW_WithConnStatistics_StreamingMode (t * testing.T ) {
300+ ctrl := gomock .NewController (t )
301+ defer ctrl .Finish ()
302+
303+ mockStats := & mockConnStatistics {
304+ activeStreams : map [string ]int {
305+ "localhost:404" : 5 ,
306+ },
307+ }
308+
309+ var contextPassedToEndpoint context.Context
310+ cli := newMockClient (t , ctrl ).(* kcFinalizerClient )
311+ mw := newResolveMWBuilder (cli .lbf , mockStats )(ctx )
312+ ep := func (ctx context.Context , request , response interface {}) error {
313+ contextPassedToEndpoint = ctx
314+ return nil
315+ }
316+
317+ to := remoteinfo .NewRemoteInfo (& rpcinfo.EndpointBasicInfo {}, "" )
318+
319+ // Create RPC config with streaming mode and gRPC protocol
320+ cfg := rpcinfo .NewRPCConfig ()
321+ rpcinfo .AsMutableRPCConfig (cfg ).SetInteractionMode (rpcinfo .Streaming )
322+ rpcinfo .AsMutableRPCConfig (cfg ).SetTransportProtocol (transport .GRPC )
323+
324+ ri := rpcinfo .NewRPCInfo (nil , to , rpcinfo .NewInvocation ("" , "" ), cfg , rpcinfo .NewRPCStats ())
325+
326+ ctx := rpcinfo .NewCtxWithRPCInfo (context .Background (), ri )
327+ req := new (MockTStruct )
328+ res := new (MockTStruct )
329+ err := mw (ep )(ctx , req , res )
330+ test .Assert (t , err == nil )
331+ test .Assert (t , to .GetInstance () == instance404 [0 ])
332+
333+ // Verify ConnStatistics was passed to context
334+ cs := remote .GetConnStatistics (contextPassedToEndpoint )
335+ test .Assert (t , cs != nil , "ConnStatistics should be in context for streaming mode" )
336+ test .Assert (t , cs .ActiveStreams ("localhost:404" ) == 5 )
337+ }
338+
339+ // TestResolverMW_WithConnStatistics_NonStreamingMode tests that ConnStatistics is NOT passed
340+ // to context when not in streaming mode
341+ func TestResolverMW_WithConnStatistics_NonStreamingMode (t * testing.T ) {
342+ ctrl := gomock .NewController (t )
343+ defer ctrl .Finish ()
344+
345+ mockStats := & mockConnStatistics {
346+ activeStreams : map [string ]int {
347+ "localhost:404" : 5 ,
348+ },
349+ }
350+
351+ var contextPassedToEndpoint context.Context
352+ cli := newMockClient (t , ctrl ).(* kcFinalizerClient )
353+ mw := newResolveMWBuilder (cli .lbf , mockStats )(ctx )
354+ ep := func (ctx context.Context , request , response interface {}) error {
355+ contextPassedToEndpoint = ctx
356+ return nil
357+ }
358+
359+ to := remoteinfo .NewRemoteInfo (& rpcinfo.EndpointBasicInfo {}, "" )
360+
361+ // Create RPC config with PingPong mode (not streaming)
362+ cfg := rpcinfo .NewRPCConfig ()
363+ rpcinfo .AsMutableRPCConfig (cfg ).SetInteractionMode (rpcinfo .PingPong )
364+ rpcinfo .AsMutableRPCConfig (cfg ).SetTransportProtocol (transport .GRPC )
365+
366+ ri := rpcinfo .NewRPCInfo (nil , to , rpcinfo .NewInvocation ("" , "" ), cfg , rpcinfo .NewRPCStats ())
367+
368+ ctx := rpcinfo .NewCtxWithRPCInfo (context .Background (), ri )
369+ req := new (MockTStruct )
370+ res := new (MockTStruct )
371+ err := mw (ep )(ctx , req , res )
372+ test .Assert (t , err == nil )
373+
374+ // Verify ConnStatistics was NOT passed to context for non-streaming mode
375+ cs := remote .GetConnStatistics (contextPassedToEndpoint )
376+ test .Assert (t , cs == nil , "ConnStatistics should not be in context for non-streaming mode" )
377+ }
378+
379+ // TestResolverMW_WithConnStatistics_NonGRPC tests that ConnStatistics is NOT passed
380+ // for non-gRPC protocols
381+ func TestResolverMW_WithConnStatistics_NonGRPC (t * testing.T ) {
382+ ctrl := gomock .NewController (t )
383+ defer ctrl .Finish ()
384+
385+ mockStats := & mockConnStatistics {
386+ activeStreams : map [string ]int {
387+ "localhost:404" : 5 ,
388+ },
389+ }
390+
391+ var contextPassedToEndpoint context.Context
392+ cli := newMockClient (t , ctrl ).(* kcFinalizerClient )
393+ mw := newResolveMWBuilder (cli .lbf , mockStats )(ctx )
394+ ep := func (ctx context.Context , request , response interface {}) error {
395+ contextPassedToEndpoint = ctx
396+ return nil
397+ }
398+
399+ to := remoteinfo .NewRemoteInfo (& rpcinfo.EndpointBasicInfo {}, "" )
400+
401+ // Create RPC config with streaming mode but non-gRPC protocol
402+ cfg := rpcinfo .NewRPCConfig ()
403+ rpcinfo .AsMutableRPCConfig (cfg ).SetInteractionMode (rpcinfo .Streaming )
404+ rpcinfo .AsMutableRPCConfig (cfg ).SetTransportProtocol (transport .TTHeader ) // Not GRPC
405+
406+ ri := rpcinfo .NewRPCInfo (nil , to , rpcinfo .NewInvocation ("" , "" ), cfg , rpcinfo .NewRPCStats ())
407+
408+ ctx := rpcinfo .NewCtxWithRPCInfo (context .Background (), ri )
409+ req := new (MockTStruct )
410+ res := new (MockTStruct )
411+ err := mw (ep )(ctx , req , res )
412+ test .Assert (t , err == nil )
413+
414+ // Verify ConnStatistics was NOT passed for non-gRPC protocol
415+ cs := remote .GetConnStatistics (contextPassedToEndpoint )
416+ test .Assert (t , cs == nil , "ConnStatistics should not be in context for non-gRPC protocol" )
417+ }
418+
419+ // TestResolverMW_WithoutConnStatistics tests behavior when ConnStatistics is nil
420+ func TestResolverMW_WithoutConnStatistics (t * testing.T ) {
421+ ctrl := gomock .NewController (t )
422+ defer ctrl .Finish ()
423+
424+ var contextPassedToEndpoint context.Context
425+ cli := newMockClient (t , ctrl ).(* kcFinalizerClient )
426+ mw := newResolveMWBuilder (cli .lbf , nil )(ctx )
427+ ep := func (ctx context.Context , request , response interface {}) error {
428+ contextPassedToEndpoint = ctx
429+ return nil
430+ }
431+
432+ to := remoteinfo .NewRemoteInfo (& rpcinfo.EndpointBasicInfo {}, "" )
433+
434+ // Create RPC config with streaming mode and gRPC protocol
435+ cfg := rpcinfo .NewRPCConfig ()
436+ rpcinfo .AsMutableRPCConfig (cfg ).SetInteractionMode (rpcinfo .Streaming )
437+ rpcinfo .AsMutableRPCConfig (cfg ).SetTransportProtocol (transport .GRPC )
438+
439+ ri := rpcinfo .NewRPCInfo (nil , to , rpcinfo .NewInvocation ("" , "" ), cfg , rpcinfo .NewRPCStats ())
440+
441+ ctx := rpcinfo .NewCtxWithRPCInfo (context .Background (), ri )
442+ req := new (MockTStruct )
443+ res := new (MockTStruct )
444+ err := mw (ep )(ctx , req , res )
445+ test .Assert (t , err == nil )
446+
447+ // Verify ConnStatistics is nil when not provided
448+ cs := remote .GetConnStatistics (contextPassedToEndpoint )
449+ test .Assert (t , cs == nil , "ConnStatistics should be nil when not provided" )
450+ }
0 commit comments