Skip to content

Commit 5cd1ec5

Browse files
committed
Fix gRPC stream status and timeout validation
1 parent 854f955 commit 5cd1ec5

6 files changed

Lines changed: 257 additions & 81 deletions

File tree

Cargo.lock

Lines changed: 6 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

e2e-tests/Cargo.lock

Lines changed: 5 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

ldk-server-client/Cargo.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,11 @@ ldk-server-protos = { path = "../ldk-server-protos" }
1212
reqwest = { version = "0.11.13", default-features = false, features = ["rustls-tls"] }
1313
prost = { version = "0.11.6", default-features = false, features = ["std", "prost-derive"] }
1414
bitcoin_hashes = "0.14"
15+
http-body = "0.4"
16+
hyper = { version = "0.14", default-features = false, features = ["client", "http2", "runtime", "tcp"] }
17+
hyper-rustls = { version = "0.24", default-features = false, features = ["http2", "tls12", "tokio-runtime"] }
18+
rustls = "0.21"
19+
rustls-pemfile = "1"
20+
21+
[dev-dependencies]
22+
tokio = { version = "1", default-features = false, features = ["macros", "rt"] }

ldk-server-client/src/client.rs

Lines changed: 178 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,14 @@
77
// You may not use this file except in accordance with one or both of these
88
// licenses.
99

10+
use std::io::Cursor;
1011
use std::time::{SystemTime, UNIX_EPOCH};
1112

1213
use bitcoin_hashes::hmac::{Hmac, HmacEngine};
1314
use bitcoin_hashes::{sha256, Hash, HashEngine};
15+
use hyper::body::HttpBody as _;
16+
use hyper::{Body as HyperBody, Client as HyperClient, Request as HyperRequest, Version};
17+
use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
1418
use ldk_server_protos::api::SubscribeEventsRequest;
1519
use ldk_server_protos::api::{
1620
Bolt11ClaimForHashRequest, Bolt11ClaimForHashResponse, Bolt11FailForHashRequest,
@@ -50,7 +54,9 @@ use ldk_server_protos::endpoints::{
5054
};
5155
use ldk_server_protos::events::EventEnvelope;
5256
use prost::Message;
53-
use reqwest::{Certificate, Client};
57+
use reqwest::{header::HeaderMap, Certificate, Client};
58+
use rustls::{ClientConfig, RootCertStore};
59+
use rustls_pemfile::certs;
5460

5561
use crate::error::LdkServerError;
5662
use crate::error::LdkServerErrorCode::{
@@ -60,6 +66,8 @@ use crate::error::LdkServerErrorCode::{
6066
/// gRPC path prefix for the LightningNode service.
6167
const GRPC_SERVICE_PREFIX: &str = "/api.LightningNode/";
6268

69+
type StreamingClient = HyperClient<HttpsConnector<hyper::client::HttpConnector>, HyperBody>;
70+
6371
/// Client to access a hosted instance of LDK Server via gRPC.
6472
///
6573
/// The client requires the server's TLS certificate to be provided for verification.
@@ -69,6 +77,7 @@ const GRPC_SERVICE_PREFIX: &str = "/api.LightningNode/";
6977
pub struct LdkServerClient {
7078
base_url: String,
7179
client: Client,
80+
streaming_client: StreamingClient,
7281
api_key: String,
7382
}
7483

@@ -82,13 +91,14 @@ impl LdkServerClient {
8291
pub fn new(base_url: String, api_key: String, server_cert_pem: &[u8]) -> Result<Self, String> {
8392
let cert = Certificate::from_pem(server_cert_pem)
8493
.map_err(|e| format!("Failed to parse server certificate: {e}"))?;
94+
let streaming_client = build_streaming_client(server_cert_pem)?;
8595

8696
let client = Client::builder()
8797
.add_root_certificate(cert)
8898
.build()
8999
.map_err(|e| format!("Failed to build HTTP client: {e}"))?;
90100

91-
Ok(Self { base_url, client, api_key })
101+
Ok(Self { base_url, client, streaming_client, api_key })
92102
}
93103

94104
/// Computes the HMAC-SHA256 authentication header value.
@@ -411,36 +421,32 @@ impl LdkServerClient {
411421
let auth_header = self.compute_auth_header();
412422

413423
let response = self
414-
.client
415-
.post(&url)
416-
.header("content-type", "application/grpc+proto")
417-
.header("te", "trailers")
418-
.header("x-auth", auth_header)
419-
.body(grpc_body)
420-
.send()
424+
.streaming_client
425+
.request(
426+
HyperRequest::post(&url)
427+
.version(Version::HTTP_2)
428+
.header("content-type", "application/grpc+proto")
429+
.header("te", "trailers")
430+
.header("x-auth", auth_header)
431+
.body(HyperBody::from(grpc_body))
432+
.map_err(|e| {
433+
LdkServerError::new(
434+
InternalError,
435+
format!("Failed to build gRPC request: {e}"),
436+
)
437+
})?,
438+
)
421439
.await
422440
.map_err(|e| {
423441
LdkServerError::new(InternalError, format!("gRPC request failed: {}", e))
424442
})?;
425443

426-
// Check for Trailers-Only error
427-
if let Some(status_val) = response.headers().get("grpc-status") {
428-
if let Ok(status_str) = status_val.to_str() {
429-
if let Ok(code) = status_str.parse::<u32>() {
430-
if code != 0 {
431-
let message = response
432-
.headers()
433-
.get("grpc-message")
434-
.and_then(|v| v.to_str().ok())
435-
.map(percent_decode)
436-
.unwrap_or_default();
437-
return Err(grpc_code_to_error(code, message));
438-
}
439-
}
440-
}
444+
let (parts, body) = response.into_parts();
445+
if let Some(error) = grpc_error_from_headers(&parts.headers) {
446+
return Err(error);
441447
}
442448

443-
Ok(EventStream { response, buf: Vec::new() })
449+
Ok(EventStream { body, buf: Vec::new(), trailers_checked: false })
444450
}
445451

446452
/// Send a unary gRPC request and decode the response.
@@ -473,20 +479,8 @@ impl LdkServerClient {
473479
// Check for Trailers-Only error responses (grpc-status in response headers).
474480
// In gRPC, when there is no response body (error case), the server sends
475481
// grpc-status as part of the initial HEADERS frame, readable as a regular header.
476-
if let Some(status_val) = response.headers().get("grpc-status") {
477-
if let Ok(status_str) = status_val.to_str() {
478-
if let Ok(code) = status_str.parse::<u32>() {
479-
if code != 0 {
480-
let message = response
481-
.headers()
482-
.get("grpc-message")
483-
.and_then(|v| v.to_str().ok())
484-
.map(percent_decode)
485-
.unwrap_or_default();
486-
return Err(grpc_code_to_error(code, message));
487-
}
488-
}
489-
}
482+
if let Some(error) = grpc_error_from_headers(response.headers()) {
483+
return Err(error);
490484
}
491485

492486
// Read the response body
@@ -515,14 +509,42 @@ impl LdkServerClient {
515509

516510
/// Map a gRPC status code to an LdkServerError.
517511
fn grpc_code_to_error(code: u32, message: String) -> LdkServerError {
518-
let error_code = match code {
519-
3 => InvalidRequestError, // INVALID_ARGUMENT
520-
16 => AuthError, // UNAUTHENTICATED
521-
9 => LightningError, // FAILED_PRECONDITION
522-
13 => InternalServerError, // INTERNAL
523-
_ => InternalError,
524-
};
525-
LdkServerError::new(error_code, message)
512+
match code {
513+
3 => LdkServerError::new(InvalidRequestError, message), // INVALID_ARGUMENT
514+
9 => LdkServerError::new(LightningError, message), // FAILED_PRECONDITION
515+
13 => LdkServerError::new(InternalServerError, message), // INTERNAL
516+
14 => LdkServerError::new(
517+
InternalError,
518+
if message.is_empty() {
519+
"gRPC stream became unavailable".to_string()
520+
} else {
521+
format!("gRPC stream became unavailable: {message}")
522+
},
523+
),
524+
16 => LdkServerError::new(AuthError, message), // UNAUTHENTICATED
525+
_ => LdkServerError::new(
526+
InternalError,
527+
if message.is_empty() {
528+
format!("gRPC status {code}")
529+
} else {
530+
format!("gRPC status {code}: {message}")
531+
},
532+
),
533+
}
534+
}
535+
536+
fn grpc_error_from_headers(headers: &HeaderMap) -> Option<LdkServerError> {
537+
let code = headers.get("grpc-status")?.to_str().ok()?.parse::<u32>().ok()?;
538+
if code == 0 {
539+
return None;
540+
}
541+
542+
let message = headers
543+
.get("grpc-message")
544+
.and_then(|v| v.to_str().ok())
545+
.map(percent_decode)
546+
.unwrap_or_default();
547+
Some(grpc_code_to_error(code, message))
526548
}
527549

528550
/// Minimal percent-decoding for grpc-message values.
@@ -556,8 +578,9 @@ fn hex_val(b: u8) -> Option<u8> {
556578
///
557579
/// Call [`next_event`](EventStream::next_event) to receive the next event from the server.
558580
pub struct EventStream {
559-
response: reqwest::Response,
581+
body: hyper::Body,
560582
buf: Vec<u8>,
583+
trailers_checked: bool,
561584
}
562585

563586
impl EventStream {
@@ -582,23 +605,126 @@ impl EventStream {
582605
}
583606

584607
// Need more data — read the next chunk from the response body
585-
match self.response.chunk().await {
586-
Ok(Some(chunk)) => self.buf.extend_from_slice(&chunk),
587-
Ok(None) => return None, // stream ended
588-
Err(e) => {
608+
match self.body.data().await {
609+
Some(Ok(chunk)) => self.buf.extend_from_slice(&chunk),
610+
Some(Err(e)) => {
589611
return Some(Err(LdkServerError::new(
590612
InternalError,
591613
format!("Failed to read event stream: {}", e),
592614
)));
593615
},
616+
None => {
617+
if self.trailers_checked {
618+
return None;
619+
}
620+
self.trailers_checked = true;
621+
return self.finish_stream().await;
622+
},
594623
}
595624
}
596625
}
626+
627+
async fn finish_stream(&mut self) -> Option<Result<EventEnvelope, LdkServerError>> {
628+
match self.body.trailers().await {
629+
Ok(Some(trailers)) => {
630+
if let Some(error) = grpc_error_from_headers(&trailers) {
631+
return Some(Err(error));
632+
}
633+
},
634+
Ok(None) => {},
635+
Err(e) => {
636+
return Some(Err(LdkServerError::new(
637+
InternalError,
638+
format!("Failed to read event stream trailers: {}", e),
639+
)));
640+
},
641+
}
642+
643+
if self.buf.is_empty() {
644+
None
645+
} else {
646+
Some(Err(LdkServerError::new(
647+
InternalError,
648+
"Event stream ended with an incomplete gRPC frame",
649+
)))
650+
}
651+
}
652+
}
653+
654+
fn build_streaming_client(server_cert_pem: &[u8]) -> Result<StreamingClient, String> {
655+
let mut pem_reader = Cursor::new(server_cert_pem);
656+
let certs =
657+
certs(&mut pem_reader).map_err(|e| format!("Failed to parse server certificate: {e}"))?;
658+
if certs.is_empty() {
659+
return Err("Failed to parse server certificate: no certificates found in PEM".to_string());
660+
}
661+
662+
let mut roots = RootCertStore::empty();
663+
let (added, _ignored) = roots.add_parsable_certificates(&certs);
664+
if added == 0 {
665+
return Err("Failed to build streaming client: certificate was not accepted".to_string());
666+
}
667+
668+
let tls_config = ClientConfig::builder()
669+
.with_safe_defaults()
670+
.with_root_certificates(roots)
671+
.with_no_client_auth();
672+
let connector = HttpsConnectorBuilder::new()
673+
.with_tls_config(tls_config)
674+
.https_only()
675+
.enable_http2()
676+
.build();
677+
678+
Ok(HyperClient::builder().http2_only(true).build(connector))
597679
}
598680

599681
#[cfg(test)]
600682
mod tests {
601683
use super::*;
684+
use hyper::Body;
685+
use reqwest::header::HeaderValue;
686+
687+
#[test]
688+
fn test_grpc_error_from_headers_ignores_ok_status() {
689+
let mut headers = HeaderMap::new();
690+
headers.insert("grpc-status", HeaderValue::from_static("0"));
691+
assert!(grpc_error_from_headers(&headers).is_none());
692+
}
693+
694+
#[test]
695+
fn test_grpc_error_from_headers_decodes_message() {
696+
let mut headers = HeaderMap::new();
697+
headers.insert("grpc-status", HeaderValue::from_static("3"));
698+
headers.insert("grpc-message", HeaderValue::from_static("bad%20request"));
699+
700+
let err = grpc_error_from_headers(&headers).unwrap();
701+
assert_eq!(err.error_code, InvalidRequestError);
702+
assert_eq!(err.message, "bad request");
703+
}
704+
705+
#[test]
706+
fn test_grpc_code_to_error_marks_unavailable_streams() {
707+
let err = grpc_code_to_error(14, "server shutting down".to_string());
708+
assert_eq!(err.error_code, InternalError);
709+
assert_eq!(err.message, "gRPC stream became unavailable: server shutting down");
710+
}
711+
712+
#[tokio::test]
713+
async fn test_event_stream_surfaces_terminal_grpc_status() {
714+
let (mut sender, body) = Body::channel();
715+
let mut trailers = HeaderMap::new();
716+
trailers.insert("grpc-status", HeaderValue::from_static("14"));
717+
trailers.insert("grpc-message", HeaderValue::from_static("server%20restarting"));
718+
sender.send_trailers(trailers).await.unwrap();
719+
drop(sender);
720+
721+
let mut stream = EventStream { body, buf: Vec::new(), trailers_checked: false };
722+
723+
let result = stream.next_event().await.unwrap().unwrap_err();
724+
assert_eq!(result.error_code, InternalError);
725+
assert_eq!(result.message, "gRPC stream became unavailable: server restarting");
726+
assert!(stream.next_event().await.is_none());
727+
}
602728

603729
#[test]
604730
fn test_hex_val() {

0 commit comments

Comments
 (0)