@@ -16,7 +16,7 @@ use std::pin::Pin;
1616use std:: task:: { Context , Poll } ;
1717
1818use bytes:: { BufMut , Bytes , BytesMut } ;
19- use hyper:: body:: { Frame , Incoming } ;
19+ use hyper:: body:: Frame ;
2020use hyper:: header:: HeaderValue ;
2121use hyper:: http:: HeaderMap ;
2222use hyper:: { Request , Response } ;
@@ -255,7 +255,7 @@ pub(crate) fn grpc_response(body: GrpcBody) -> Response<GrpcBody> {
255255}
256256
257257/// Validate that the request looks like a gRPC call.
258- pub ( crate ) fn validate_grpc_request ( req : & Request < Incoming > ) -> Result < ( ) , GrpcStatus > {
258+ pub ( crate ) fn validate_grpc_request < B > ( req : & Request < B > ) -> Result < ( ) , GrpcStatus > {
259259 if req. method ( ) != hyper:: Method :: POST {
260260 return Err ( GrpcStatus :: new ( GRPC_STATUS_UNIMPLEMENTED , "gRPC requires POST method" ) ) ;
261261 }
@@ -445,4 +445,174 @@ mod tests {
445445 assert_eq ! ( parse_grpc_timeout( "S" ) , None ) ;
446446 assert_eq ! ( parse_grpc_timeout( "5x" ) , None ) ;
447447 }
448+
449+ #[ test]
450+ fn test_validate_grpc_request_valid ( ) {
451+ let req = Request :: builder ( )
452+ . method ( "POST" )
453+ . header ( "content-type" , "application/grpc" )
454+ . body ( ( ) )
455+ . unwrap ( ) ;
456+ assert ! ( validate_grpc_request( & req) . is_ok( ) ) ;
457+
458+ let req = Request :: builder ( )
459+ . method ( "POST" )
460+ . header ( "content-type" , "application/grpc+proto" )
461+ . body ( ( ) )
462+ . unwrap ( ) ;
463+ assert ! ( validate_grpc_request( & req) . is_ok( ) ) ;
464+ }
465+
466+ #[ test]
467+ fn test_validate_grpc_request_wrong_method ( ) {
468+ let req = Request :: builder ( )
469+ . method ( "GET" )
470+ . header ( "content-type" , "application/grpc" )
471+ . body ( ( ) )
472+ . unwrap ( ) ;
473+ let err = validate_grpc_request ( & req) . unwrap_err ( ) ;
474+ assert_eq ! ( err. code, GRPC_STATUS_UNIMPLEMENTED ) ;
475+ }
476+
477+ #[ test]
478+ fn test_validate_grpc_request_wrong_content_type ( ) {
479+ let cases =
480+ [ "application/json" , "application/grpc+json" , "application/grpcfoo" , "text/plain" , "" ] ;
481+ for ct in & cases {
482+ let req =
483+ Request :: builder ( ) . method ( "POST" ) . header ( "content-type" , * ct) . body ( ( ) ) . unwrap ( ) ;
484+ let err = validate_grpc_request ( & req) . unwrap_err ( ) ;
485+ assert_eq ! ( err. code, GRPC_STATUS_INVALID_ARGUMENT , "should reject content-type {ct:?}" ) ;
486+ }
487+ }
488+
489+ #[ test]
490+ fn test_validate_grpc_request_missing_content_type ( ) {
491+ let req = Request :: builder ( ) . method ( "POST" ) . body ( ( ) ) . unwrap ( ) ;
492+ let err = validate_grpc_request ( & req) . unwrap_err ( ) ;
493+ assert_eq ! ( err. code, GRPC_STATUS_INVALID_ARGUMENT ) ;
494+ }
495+
496+ #[ test]
497+ fn test_ldk_error_to_grpc_status_all_variants ( ) {
498+ let cases = [
499+ ( LdkServerErrorCode :: InvalidRequestError , GRPC_STATUS_INVALID_ARGUMENT ) ,
500+ ( LdkServerErrorCode :: AuthError , GRPC_STATUS_UNAUTHENTICATED ) ,
501+ ( LdkServerErrorCode :: LightningError , GRPC_STATUS_FAILED_PRECONDITION ) ,
502+ ( LdkServerErrorCode :: InternalServerError , GRPC_STATUS_INTERNAL ) ,
503+ ] ;
504+ for ( error_code, expected_grpc_code) in cases {
505+ let e = LdkServerError :: new ( error_code, "test message" ) ;
506+ let s = ldk_error_to_grpc_status ( e) ;
507+ assert_eq ! ( s. code, expected_grpc_code) ;
508+ assert_eq ! ( s. message, "test message" ) ;
509+ }
510+ }
511+
512+ #[ test]
513+ fn test_grpc_error_response_structure ( ) {
514+ let status = GrpcStatus :: new ( GRPC_STATUS_INTERNAL , "something broke" ) ;
515+ let resp = grpc_error_response ( status) ;
516+
517+ assert_eq ! ( resp. status( ) , 200 ) ;
518+ assert_eq ! ( resp. headers( ) . get( "content-type" ) . unwrap( ) , GRPC_CONTENT_TYPE ) ;
519+ assert_eq ! (
520+ resp. headers( ) . get( GRPC_ACCEPT_ENCODING_HEADER ) . unwrap( ) ,
521+ GRPC_ENCODING_IDENTITY
522+ ) ;
523+ assert_eq ! (
524+ resp. headers( ) . get( GRPC_STATUS_HEADER ) . unwrap( ) ,
525+ & GRPC_STATUS_INTERNAL . to_string( )
526+ ) ;
527+ assert_eq ! ( resp. headers( ) . get( GRPC_MESSAGE_HEADER ) . unwrap( ) , "something broke" ) ;
528+ }
529+
530+ #[ test]
531+ fn test_grpc_error_response_empty_message_omits_grpc_message ( ) {
532+ let status = GrpcStatus :: new ( GRPC_STATUS_UNAUTHENTICATED , "" ) ;
533+ let resp = grpc_error_response ( status) ;
534+
535+ assert_eq ! ( resp. headers( ) . get( GRPC_STATUS_HEADER ) . unwrap( ) , "16" ) ;
536+ assert ! ( resp. headers( ) . get( GRPC_MESSAGE_HEADER ) . is_none( ) ) ;
537+ }
538+
539+ #[ test]
540+ fn test_grpc_response_headers ( ) {
541+ let body = GrpcBody :: Unary { data : Some ( Bytes :: new ( ) ) , trailers_sent : false } ;
542+ let resp = grpc_response ( body) ;
543+
544+ assert_eq ! ( resp. status( ) , 200 ) ;
545+ assert_eq ! ( resp. headers( ) . get( "content-type" ) . unwrap( ) , GRPC_CONTENT_TYPE ) ;
546+ assert_eq ! (
547+ resp. headers( ) . get( GRPC_ACCEPT_ENCODING_HEADER ) . unwrap( ) ,
548+ GRPC_ENCODING_IDENTITY
549+ ) ;
550+ }
551+
552+ #[ test]
553+ fn test_grpc_body_unary_poll_sequence ( ) {
554+ use hyper:: body:: Body ;
555+ use std:: task:: { Context , Poll , RawWaker , RawWakerVTable , Waker } ;
556+
557+ fn noop_waker ( ) -> Waker {
558+ fn no_op ( _: * const ( ) ) { }
559+ fn clone ( p : * const ( ) ) -> RawWaker {
560+ RawWaker :: new ( p, & VTABLE )
561+ }
562+ static VTABLE : RawWakerVTable = RawWakerVTable :: new ( clone, no_op, no_op, no_op) ;
563+ unsafe { Waker :: from_raw ( RawWaker :: new ( std:: ptr:: null ( ) , & VTABLE ) ) }
564+ }
565+
566+ let payload = encode_grpc_frame ( b"test" ) ;
567+ let mut body = GrpcBody :: Unary { data : Some ( payload. clone ( ) ) , trailers_sent : false } ;
568+
569+ let waker = noop_waker ( ) ;
570+ let mut cx = Context :: from_waker ( & waker) ;
571+
572+ // First poll: data frame
573+ let frame = Pin :: new ( & mut body) . poll_frame ( & mut cx) ;
574+ match frame {
575+ Poll :: Ready ( Some ( Ok ( ref f) ) ) => assert ! ( f. is_data( ) ) ,
576+ ref other => panic ! ( "expected data frame, got {other:?}" ) ,
577+ }
578+
579+ // Second poll: trailers
580+ let frame = Pin :: new ( & mut body) . poll_frame ( & mut cx) ;
581+ match frame {
582+ Poll :: Ready ( Some ( Ok ( f) ) ) => {
583+ let trailers = f. into_trailers ( ) . expect ( "expected trailers" ) ;
584+ assert_eq ! ( trailers. get( GRPC_STATUS_HEADER ) . unwrap( ) , "0" ) ;
585+ } ,
586+ ref other => panic ! ( "expected trailers frame, got {other:?}" ) ,
587+ }
588+
589+ // Third poll: end of stream
590+ let frame: Poll < Option < Result < hyper:: body:: Frame < Bytes > , hyper:: Error > > > =
591+ Pin :: new ( & mut body) . poll_frame ( & mut cx) ;
592+ assert ! ( matches!( frame, Poll :: Ready ( None ) ) ) ;
593+ }
594+
595+ #[ test]
596+ fn test_grpc_body_empty_poll ( ) {
597+ use hyper:: body:: Body ;
598+ use std:: task:: { Context , Poll , RawWaker , RawWakerVTable , Waker } ;
599+
600+ fn noop_waker ( ) -> Waker {
601+ fn no_op ( _: * const ( ) ) { }
602+ fn clone ( p : * const ( ) ) -> RawWaker {
603+ RawWaker :: new ( p, & VTABLE )
604+ }
605+ static VTABLE : RawWakerVTable = RawWakerVTable :: new ( clone, no_op, no_op, no_op) ;
606+ unsafe { Waker :: from_raw ( RawWaker :: new ( std:: ptr:: null ( ) , & VTABLE ) ) }
607+ }
608+
609+ let mut body = GrpcBody :: Empty ;
610+ let waker = noop_waker ( ) ;
611+ let mut cx = Context :: from_waker ( & waker) ;
612+
613+ // Empty body returns None immediately
614+ let frame: Poll < Option < Result < hyper:: body:: Frame < Bytes > , hyper:: Error > > > =
615+ Pin :: new ( & mut body) . poll_frame ( & mut cx) ;
616+ assert ! ( matches!( frame, Poll :: Ready ( None ) ) ) ;
617+ }
448618}
0 commit comments