Skip to content

Commit 1d92fff

Browse files
committed
Fix gRPC stream status and timeout validation
1 parent 59f845c commit 1d92fff

6 files changed

Lines changed: 257 additions & 81 deletions

File tree

Cargo.lock

Lines changed: 6 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

e2e-tests/Cargo.lock

Lines changed: 5 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

ldk-server-client/Cargo.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,11 @@ ldk-server-protos = { path = "../ldk-server-protos" }
1212
reqwest = { version = "0.11.13", default-features = false, features = ["rustls-tls"] }
1313
prost = { version = "0.11.6", default-features = false, features = ["std", "prost-derive"] }
1414
bitcoin_hashes = "0.14"
15+
http-body = "0.4"
16+
hyper = { version = "0.14", default-features = false, features = ["client", "http2", "runtime", "tcp"] }
17+
hyper-rustls = { version = "0.24", default-features = false, features = ["http2", "tls12", "tokio-runtime"] }
18+
rustls = "0.21"
19+
rustls-pemfile = "1"
20+
21+
[dev-dependencies]
22+
tokio = { version = "1", default-features = false, features = ["macros", "rt"] }

ldk-server-client/src/client.rs

Lines changed: 178 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,14 @@
77
// You may not use this file except in accordance with one or both of these
88
// licenses.
99

10+
use std::io::Cursor;
1011
use std::time::{SystemTime, UNIX_EPOCH};
1112

1213
use bitcoin_hashes::hmac::{Hmac, HmacEngine};
1314
use bitcoin_hashes::{sha256, Hash, HashEngine};
15+
use hyper::body::HttpBody as _;
16+
use hyper::{Body as HyperBody, Client as HyperClient, Request as HyperRequest, Version};
17+
use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
1418
use ldk_server_protos::api::SubscribeEventsRequest;
1519
use ldk_server_protos::api::{
1620
Bolt11ClaimForHashRequest, Bolt11ClaimForHashResponse, Bolt11FailForHashRequest,
@@ -50,7 +54,9 @@ use ldk_server_protos::endpoints::{
5054
};
5155
use ldk_server_protos::events::EventEnvelope;
5256
use prost::Message;
53-
use reqwest::{Certificate, Client};
57+
use reqwest::{header::HeaderMap, Certificate, Client};
58+
use rustls::{ClientConfig, RootCertStore};
59+
use rustls_pemfile::certs;
5460

5561
use crate::error::LdkServerError;
5662
use crate::error::LdkServerErrorCode::{
@@ -60,6 +66,8 @@ use crate::error::LdkServerErrorCode::{
6066
/// gRPC path prefix for the LightningNode service.
6167
const GRPC_SERVICE_PREFIX: &str = "/api.LightningNode/";
6268

69+
type StreamingClient = HyperClient<HttpsConnector<hyper::client::HttpConnector>, HyperBody>;
70+
6371
/// Client to access a hosted instance of LDK Server via gRPC.
6472
///
6573
/// The client requires the server's TLS certificate to be provided for verification.
@@ -69,6 +77,7 @@ const GRPC_SERVICE_PREFIX: &str = "/api.LightningNode/";
6977
pub struct LdkServerClient {
7078
base_url: String,
7179
client: Client,
80+
streaming_client: StreamingClient,
7281
api_key: String,
7382
}
7483

@@ -82,13 +91,14 @@ impl LdkServerClient {
8291
pub fn new(base_url: String, api_key: String, server_cert_pem: &[u8]) -> Result<Self, String> {
8392
let cert = Certificate::from_pem(server_cert_pem)
8493
.map_err(|e| format!("Failed to parse server certificate: {e}"))?;
94+
let streaming_client = build_streaming_client(server_cert_pem)?;
8595

8696
let client = Client::builder()
8797
.add_root_certificate(cert)
8898
.build()
8999
.map_err(|e| format!("Failed to build HTTP client: {e}"))?;
90100

91-
Ok(Self { base_url, client, api_key })
101+
Ok(Self { base_url, client, streaming_client, api_key })
92102
}
93103

94104
/// Computes the HMAC-SHA256 authentication header value.
@@ -417,36 +427,32 @@ impl LdkServerClient {
417427
let auth_header = self.compute_auth_header();
418428

419429
let response = self
420-
.client
421-
.post(&url)
422-
.header("content-type", "application/grpc+proto")
423-
.header("te", "trailers")
424-
.header("x-auth", auth_header)
425-
.body(grpc_body)
426-
.send()
430+
.streaming_client
431+
.request(
432+
HyperRequest::post(&url)
433+
.version(Version::HTTP_2)
434+
.header("content-type", "application/grpc+proto")
435+
.header("te", "trailers")
436+
.header("x-auth", auth_header)
437+
.body(HyperBody::from(grpc_body))
438+
.map_err(|e| {
439+
LdkServerError::new(
440+
InternalError,
441+
format!("Failed to build gRPC request: {e}"),
442+
)
443+
})?,
444+
)
427445
.await
428446
.map_err(|e| {
429447
LdkServerError::new(InternalError, format!("gRPC request failed: {}", e))
430448
})?;
431449

432-
// Check for Trailers-Only error
433-
if let Some(status_val) = response.headers().get("grpc-status") {
434-
if let Ok(status_str) = status_val.to_str() {
435-
if let Ok(code) = status_str.parse::<u32>() {
436-
if code != 0 {
437-
let message = response
438-
.headers()
439-
.get("grpc-message")
440-
.and_then(|v| v.to_str().ok())
441-
.map(percent_decode)
442-
.unwrap_or_default();
443-
return Err(grpc_code_to_error(code, message));
444-
}
445-
}
446-
}
450+
let (parts, body) = response.into_parts();
451+
if let Some(error) = grpc_error_from_headers(&parts.headers) {
452+
return Err(error);
447453
}
448454

449-
Ok(EventStream { response, buf: Vec::new() })
455+
Ok(EventStream { body, buf: Vec::new(), trailers_checked: false })
450456
}
451457

452458
/// Send a unary gRPC request and decode the response.
@@ -479,20 +485,8 @@ impl LdkServerClient {
479485
// Check for Trailers-Only error responses (grpc-status in response headers).
480486
// In gRPC, when there is no response body (error case), the server sends
481487
// grpc-status as part of the initial HEADERS frame, readable as a regular header.
482-
if let Some(status_val) = response.headers().get("grpc-status") {
483-
if let Ok(status_str) = status_val.to_str() {
484-
if let Ok(code) = status_str.parse::<u32>() {
485-
if code != 0 {
486-
let message = response
487-
.headers()
488-
.get("grpc-message")
489-
.and_then(|v| v.to_str().ok())
490-
.map(percent_decode)
491-
.unwrap_or_default();
492-
return Err(grpc_code_to_error(code, message));
493-
}
494-
}
495-
}
488+
if let Some(error) = grpc_error_from_headers(response.headers()) {
489+
return Err(error);
496490
}
497491

498492
// Read the response body
@@ -521,14 +515,42 @@ impl LdkServerClient {
521515

522516
/// Map a gRPC status code to an LdkServerError.
523517
fn grpc_code_to_error(code: u32, message: String) -> LdkServerError {
524-
let error_code = match code {
525-
3 => InvalidRequestError, // INVALID_ARGUMENT
526-
16 => AuthError, // UNAUTHENTICATED
527-
9 => LightningError, // FAILED_PRECONDITION
528-
13 => InternalServerError, // INTERNAL
529-
_ => InternalError,
530-
};
531-
LdkServerError::new(error_code, message)
518+
match code {
519+
3 => LdkServerError::new(InvalidRequestError, message), // INVALID_ARGUMENT
520+
9 => LdkServerError::new(LightningError, message), // FAILED_PRECONDITION
521+
13 => LdkServerError::new(InternalServerError, message), // INTERNAL
522+
14 => LdkServerError::new(
523+
InternalError,
524+
if message.is_empty() {
525+
"gRPC stream became unavailable".to_string()
526+
} else {
527+
format!("gRPC stream became unavailable: {message}")
528+
},
529+
),
530+
16 => LdkServerError::new(AuthError, message), // UNAUTHENTICATED
531+
_ => LdkServerError::new(
532+
InternalError,
533+
if message.is_empty() {
534+
format!("gRPC status {code}")
535+
} else {
536+
format!("gRPC status {code}: {message}")
537+
},
538+
),
539+
}
540+
}
541+
542+
fn grpc_error_from_headers(headers: &HeaderMap) -> Option<LdkServerError> {
543+
let code = headers.get("grpc-status")?.to_str().ok()?.parse::<u32>().ok()?;
544+
if code == 0 {
545+
return None;
546+
}
547+
548+
let message = headers
549+
.get("grpc-message")
550+
.and_then(|v| v.to_str().ok())
551+
.map(percent_decode)
552+
.unwrap_or_default();
553+
Some(grpc_code_to_error(code, message))
532554
}
533555

534556
/// Minimal percent-decoding for grpc-message values.
@@ -562,8 +584,9 @@ fn hex_val(b: u8) -> Option<u8> {
562584
///
563585
/// Call [`next_event`](EventStream::next_event) to receive the next event from the server.
564586
pub struct EventStream {
565-
response: reqwest::Response,
587+
body: hyper::Body,
566588
buf: Vec<u8>,
589+
trailers_checked: bool,
567590
}
568591

569592
impl EventStream {
@@ -588,23 +611,126 @@ impl EventStream {
588611
}
589612

590613
// Need more data — read the next chunk from the response body
591-
match self.response.chunk().await {
592-
Ok(Some(chunk)) => self.buf.extend_from_slice(&chunk),
593-
Ok(None) => return None, // stream ended
594-
Err(e) => {
614+
match self.body.data().await {
615+
Some(Ok(chunk)) => self.buf.extend_from_slice(&chunk),
616+
Some(Err(e)) => {
595617
return Some(Err(LdkServerError::new(
596618
InternalError,
597619
format!("Failed to read event stream: {}", e),
598620
)));
599621
},
622+
None => {
623+
if self.trailers_checked {
624+
return None;
625+
}
626+
self.trailers_checked = true;
627+
return self.finish_stream().await;
628+
},
600629
}
601630
}
602631
}
632+
633+
async fn finish_stream(&mut self) -> Option<Result<EventEnvelope, LdkServerError>> {
634+
match self.body.trailers().await {
635+
Ok(Some(trailers)) => {
636+
if let Some(error) = grpc_error_from_headers(&trailers) {
637+
return Some(Err(error));
638+
}
639+
},
640+
Ok(None) => {},
641+
Err(e) => {
642+
return Some(Err(LdkServerError::new(
643+
InternalError,
644+
format!("Failed to read event stream trailers: {}", e),
645+
)));
646+
},
647+
}
648+
649+
if self.buf.is_empty() {
650+
None
651+
} else {
652+
Some(Err(LdkServerError::new(
653+
InternalError,
654+
"Event stream ended with an incomplete gRPC frame",
655+
)))
656+
}
657+
}
658+
}
659+
660+
fn build_streaming_client(server_cert_pem: &[u8]) -> Result<StreamingClient, String> {
661+
let mut pem_reader = Cursor::new(server_cert_pem);
662+
let certs =
663+
certs(&mut pem_reader).map_err(|e| format!("Failed to parse server certificate: {e}"))?;
664+
if certs.is_empty() {
665+
return Err("Failed to parse server certificate: no certificates found in PEM".to_string());
666+
}
667+
668+
let mut roots = RootCertStore::empty();
669+
let (added, _ignored) = roots.add_parsable_certificates(&certs);
670+
if added == 0 {
671+
return Err("Failed to build streaming client: certificate was not accepted".to_string());
672+
}
673+
674+
let tls_config = ClientConfig::builder()
675+
.with_safe_defaults()
676+
.with_root_certificates(roots)
677+
.with_no_client_auth();
678+
let connector = HttpsConnectorBuilder::new()
679+
.with_tls_config(tls_config)
680+
.https_only()
681+
.enable_http2()
682+
.build();
683+
684+
Ok(HyperClient::builder().http2_only(true).build(connector))
603685
}
604686

605687
#[cfg(test)]
606688
mod tests {
607689
use super::*;
690+
use hyper::Body;
691+
use reqwest::header::HeaderValue;
692+
693+
#[test]
694+
fn test_grpc_error_from_headers_ignores_ok_status() {
695+
let mut headers = HeaderMap::new();
696+
headers.insert("grpc-status", HeaderValue::from_static("0"));
697+
assert!(grpc_error_from_headers(&headers).is_none());
698+
}
699+
700+
#[test]
701+
fn test_grpc_error_from_headers_decodes_message() {
702+
let mut headers = HeaderMap::new();
703+
headers.insert("grpc-status", HeaderValue::from_static("3"));
704+
headers.insert("grpc-message", HeaderValue::from_static("bad%20request"));
705+
706+
let err = grpc_error_from_headers(&headers).unwrap();
707+
assert_eq!(err.error_code, InvalidRequestError);
708+
assert_eq!(err.message, "bad request");
709+
}
710+
711+
#[test]
712+
fn test_grpc_code_to_error_marks_unavailable_streams() {
713+
let err = grpc_code_to_error(14, "server shutting down".to_string());
714+
assert_eq!(err.error_code, InternalError);
715+
assert_eq!(err.message, "gRPC stream became unavailable: server shutting down");
716+
}
717+
718+
#[tokio::test]
719+
async fn test_event_stream_surfaces_terminal_grpc_status() {
720+
let (mut sender, body) = Body::channel();
721+
let mut trailers = HeaderMap::new();
722+
trailers.insert("grpc-status", HeaderValue::from_static("14"));
723+
trailers.insert("grpc-message", HeaderValue::from_static("server%20restarting"));
724+
sender.send_trailers(trailers).await.unwrap();
725+
drop(sender);
726+
727+
let mut stream = EventStream { body, buf: Vec::new(), trailers_checked: false };
728+
729+
let result = stream.next_event().await.unwrap().unwrap_err();
730+
assert_eq!(result.error_code, InternalError);
731+
assert_eq!(result.message, "gRPC stream became unavailable: server restarting");
732+
assert!(stream.next_event().await.is_none());
733+
}
608734

609735
#[test]
610736
fn test_hex_val() {

0 commit comments

Comments
 (0)