Skip to content

Commit 43b12a9

Browse files
committed
feat: use axum for ohttp gateway middleware
1 parent db8f835 commit 43b12a9

File tree

9 files changed

+276
-339
lines changed

9 files changed

+276
-339
lines changed

Cargo-minimal.lock

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2441,6 +2441,8 @@ dependencies = [
24412441
name = "ohttp-relay"
24422442
version = "0.0.11"
24432443
dependencies = [
2444+
"bhttp",
2445+
"bitcoin-ohttp",
24442446
"byteorder",
24452447
"bytes",
24462448
"futures",
@@ -2695,8 +2697,6 @@ dependencies = [
26952697
"bitcoin-ohttp",
26962698
"clap",
26972699
"config",
2698-
"http-body-util",
2699-
"hyper",
27002700
"ohttp-relay",
27012701
"payjoin-directory",
27022702
"payjoin-test-utils",

Cargo-recent.lock

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2441,6 +2441,8 @@ dependencies = [
24412441
name = "ohttp-relay"
24422442
version = "0.0.11"
24432443
dependencies = [
2444+
"bhttp",
2445+
"bitcoin-ohttp",
24442446
"byteorder",
24452447
"bytes",
24462448
"futures",
@@ -2695,8 +2697,6 @@ dependencies = [
26952697
"bitcoin-ohttp",
26962698
"clap",
26972699
"config",
2698-
"http-body-util",
2699-
"hyper",
27002700
"ohttp-relay",
27012701
"payjoin-directory",
27022702
"payjoin-test-utils",

ohttp-relay/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ tokio-util = { version = "0.7.16", features = ["net", "codec"] }
4747
tower = "0.5"
4848
tracing = "0.1.41"
4949
tracing-subscriber = { version = "0.3.20", features = ["env-filter"] }
50+
ohttp = { package = "bitcoin-ohttp", version = "0.6" }
51+
bhttp = { version = "0.6.1", features = ["http"] }
5052

5153
[dev-dependencies]
5254
mockito = "1.7.0"

ohttp-relay/src/gateway_helpers.rs

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
use std::io::Cursor;
2+
3+
pub const CHACHA20_POLY1305_NONCE_LEN: usize = 32;
4+
pub const POLY1305_TAG_SIZE: usize = 16;
5+
pub const ENCAPSULATED_MESSAGE_BYTES: usize = 65536;
6+
pub const BHTTP_REQ_BYTES: usize =
7+
ENCAPSULATED_MESSAGE_BYTES - (CHACHA20_POLY1305_NONCE_LEN + POLY1305_TAG_SIZE);
8+
9+
#[derive(Debug)]
10+
pub enum GatewayError {
11+
BadRequest(String),
12+
OhttpKeyRejection(String),
13+
InternalServerError(String),
14+
}
15+
16+
impl std::fmt::Display for GatewayError {
17+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18+
match self {
19+
GatewayError::BadRequest(msg) => write!(f, "Bad request: {}", msg),
20+
GatewayError::OhttpKeyRejection(msg) => write!(f, "OHTTP key rejection: {}", msg),
21+
GatewayError::InternalServerError(msg) => write!(f, "Internal server error: {}", msg),
22+
}
23+
}
24+
}
25+
26+
impl std::error::Error for GatewayError {}
27+
28+
/// Represents the decapsulated HTTP request extracted from OHTTP
29+
pub struct DecapsulatedRequest {
30+
pub method: String,
31+
pub uri: String,
32+
pub headers: Vec<(String, String)>,
33+
pub body: Vec<u8>,
34+
}
35+
36+
pub fn decapsulate_ohttp_request(
37+
ohttp_body: &[u8],
38+
ohttp_server: &ohttp::Server,
39+
) -> Result<(DecapsulatedRequest, ohttp::ServerResponse), GatewayError> {
40+
let (bhttp_req, res_ctx) = ohttp_server.decapsulate(ohttp_body).map_err(|e| {
41+
GatewayError::OhttpKeyRejection(format!("OHTTP decapsulation failed: {}", e))
42+
})?;
43+
44+
let mut cursor = Cursor::new(bhttp_req);
45+
let bhttp_msg = bhttp::Message::read_bhttp(&mut cursor)
46+
.map_err(|e| GatewayError::BadRequest(format!("Invalid BHTTP: {}", e)))?;
47+
48+
let method = String::from_utf8(bhttp_msg.control().method().unwrap_or_default().to_vec())
49+
.unwrap_or_else(|_| "GET".to_string());
50+
51+
let uri = format!(
52+
"{}://{}{}",
53+
std::str::from_utf8(bhttp_msg.control().scheme().unwrap_or_default()).unwrap_or("https"),
54+
std::str::from_utf8(bhttp_msg.control().authority().unwrap_or_default())
55+
.unwrap_or("localhost"),
56+
std::str::from_utf8(bhttp_msg.control().path().unwrap_or_default()).unwrap_or("/")
57+
);
58+
59+
let mut headers = Vec::new();
60+
for field in bhttp_msg.header().fields() {
61+
let name = String::from_utf8_lossy(field.name()).to_string();
62+
let value = String::from_utf8_lossy(field.value()).to_string();
63+
headers.push((name, value));
64+
}
65+
66+
let body = bhttp_msg.content().to_vec();
67+
68+
Ok((DecapsulatedRequest { method, uri, headers, body }, res_ctx))
69+
}
70+
71+
pub fn encapsulate_ohttp_response(
72+
status_code: u16,
73+
headers: Vec<(String, String)>,
74+
body: Vec<u8>,
75+
res_ctx: ohttp::ServerResponse,
76+
) -> Result<Vec<u8>, GatewayError> {
77+
let bhttp_status = bhttp::StatusCode::try_from(status_code)
78+
.map_err(|e| GatewayError::InternalServerError(format!("Invalid status code: {}", e)))?;
79+
80+
let mut bhttp_res = bhttp::Message::response(bhttp_status);
81+
82+
for (name, value) in &headers {
83+
bhttp_res.put_header(name.as_str(), value.as_str());
84+
}
85+
86+
bhttp_res.write_content(&body);
87+
88+
let mut bhttp_bytes = Vec::new();
89+
bhttp_res.write_bhttp(bhttp::Mode::KnownLength, &mut bhttp_bytes).map_err(|e| {
90+
GatewayError::InternalServerError(format!("BHTTP serialization failed: {}", e))
91+
})?;
92+
93+
bhttp_bytes.resize(BHTTP_REQ_BYTES, 0);
94+
95+
let ohttp_res = res_ctx.encapsulate(&bhttp_bytes).map_err(|e| {
96+
GatewayError::InternalServerError(format!("OHTTP encapsulation failed: {}", e))
97+
})?;
98+
99+
assert!(
100+
ohttp_res.len() == ENCAPSULATED_MESSAGE_BYTES,
101+
"Unexpected OHTTP response size: {} != {}",
102+
ohttp_res.len(),
103+
ENCAPSULATED_MESSAGE_BYTES
104+
);
105+
106+
Ok(ohttp_res)
107+
}

ohttp-relay/src/lib.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@ pub mod gateway_prober;
3636
mod gateway_uri;
3737
pub mod sentinel;
3838
pub use sentinel::SentinelTag;
39+
pub mod gateway_helpers;
40+
41+
pub use gateway_helpers::{
42+
decapsulate_ohttp_request, encapsulate_ohttp_response, BHTTP_REQ_BYTES,
43+
ENCAPSULATED_MESSAGE_BYTES,
44+
};
3945

4046
use crate::error::{BoxError, Error};
4147

payjoin-service/Cargo.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,6 @@ tokio-stream = { version = "0.1.17", optional = true }
4646
tower = "0.5"
4747
tracing = "0.1"
4848
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
49-
http-body-util = "0.1"
50-
hyper = "1.8.1"
5149
ohttp = { package = "bitcoin-ohttp", version = "0.6" }
5250
bhttp = { version = "0.6.1", features = ["http"] }
5351

payjoin-service/src/lib.rs

Lines changed: 51 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -10,40 +10,34 @@ use config::Config;
1010
use ohttp_relay::SentinelTag;
1111
use rand::Rng;
1212
use tokio_listener::{Listener, SystemOptions, UserOptions};
13-
use tower::{Service, ServiceBuilder};
13+
use tower::{Service, ServiceBuilder, ServiceExt};
1414
use tracing::info;
15-
pub mod ohttp;
16-
17-
use http_body_util::combinators::BoxBody;
18-
use hyper::body::Bytes;
19-
use hyper::{Request, StatusCode};
20-
use ohttp::{OhttpGatewayConfig, OhttpGatewayLayer};
21-
use tower::ServiceExt;
2215

2316
pub mod cli;
2417
pub mod config;
2518
pub mod metrics;
2619
pub mod middleware;
20+
pub mod ohttp;
2721

2822
use crate::metrics::MetricsService;
2923
use crate::middleware::{track_connections, track_metrics};
24+
use crate::ohttp::OhttpGatewayConfig;
3025

3126
#[derive(Clone)]
3227
struct Services {
3328
directory: payjoin_directory::Service<payjoin_directory::FilesDb>,
3429
relay: ohttp_relay::Service,
35-
sentinel_tag: SentinelTag,
30+
ohttp_config: OhttpGatewayConfig,
3631
}
3732

3833
pub async fn serve(config: Config) -> anyhow::Result<()> {
3934
let sentinel_tag = generate_sentinel_tag();
4035
let metrics = MetricsService::new()?;
36+
let directory = init_directory(&config, sentinel_tag).await?;
37+
let ohttp_config = OhttpGatewayConfig::new(directory.ohttp.clone(), sentinel_tag);
4138

42-
let services = Services {
43-
directory: init_directory(&config, sentinel_tag).await?,
44-
relay: ohttp_relay::Service::new(sentinel_tag).await,
45-
sentinel_tag,
46-
};
39+
let services =
40+
Services { directory, relay: ohttp_relay::Service::new(sentinel_tag).await, ohttp_config };
4741

4842
let app = build_app(services, metrics.clone());
4943
let _ = spawn_metrics_server(config.metrics.listener.clone(), metrics).await?;
@@ -70,16 +64,17 @@ pub async fn serve_manual_tls(
7064
tls_config: Option<axum_server::tls_rustls::RustlsConfig>,
7165
root_store: rustls::RootCertStore,
7266
) -> anyhow::Result<(u16, u16, tokio::task::JoinHandle<anyhow::Result<()>>)> {
73-
use std::net::SocketAddr;
74-
7567
let sentinel_tag = generate_sentinel_tag();
7668
let metrics = MetricsService::new()?;
69+
let directory = init_directory(&config, sentinel_tag).await?;
70+
let ohttp_config = OhttpGatewayConfig::new(directory.ohttp.clone(), sentinel_tag);
7771

7872
let services = Services {
79-
directory: init_directory(&config, sentinel_tag).await?,
73+
directory,
8074
relay: ohttp_relay::Service::new_with_roots(root_store, sentinel_tag).await,
81-
sentinel_tag,
75+
ohttp_config,
8276
};
77+
8378
let app = build_app(services, metrics.clone());
8479
let metrics_port = spawn_metrics_server(config.metrics.listener.clone(), metrics).await?;
8580

@@ -126,11 +121,12 @@ pub async fn serve_acme(config: Config) -> anyhow::Result<()> {
126121

127122
let sentinel_tag = generate_sentinel_tag();
128123
let metrics = MetricsService::new()?;
124+
let directory = init_directory(&config, sentinel_tag).await?;
125+
let ohttp_config = OhttpGatewayConfig::new(directory.ohttp.clone(), sentinel_tag);
126+
127+
let services =
128+
Services { directory, relay: ohttp_relay::Service::new(sentinel_tag).await, ohttp_config };
129129

130-
let services = Services {
131-
directory: init_directory(&config, sentinel_tag).await?,
132-
relay: ohttp_relay::Service::new(sentinel_tag).await,
133-
};
134130
let app = build_app(services, metrics.clone());
135131
let _ = spawn_metrics_server(config.metrics.listener.clone(), metrics).await?;
136132

@@ -246,71 +242,54 @@ async fn spawn_metrics_server(
246242
Ok(actual_port)
247243
}
248244

249-
async fn route_request(
250-
State(services): State<Services>,
251-
req: axum::extract::Request,
252-
) -> Response {
245+
async fn route_request(State(services): State<Services>, req: axum::extract::Request) -> Response {
253246
if is_relay_request(&req) {
254247
let mut relay = services.relay.clone();
255248
match relay.call(req).await {
256249
Ok(res) => res.into_response(),
257-
Err(e) => (StatusCode::BAD_GATEWAY, e.to_string()).into_response(),
250+
Err(e) => (axum::http::StatusCode::BAD_GATEWAY, e.to_string()).into_response(),
258251
}
259252
} else {
253+
// The directory service handles all other requests (including 404)
260254
handle_directory_request(services, req).await
261255
}
262256
}
263257

264258
async fn handle_directory_request(services: Services, req: axum::extract::Request) -> Response {
265-
let ohttp_server = services.directory.ohttp.clone();
266-
267-
let ohttp_config = OhttpGatewayConfig::new(ohttp_server, services.sentinel_tag);
268-
269-
let (parts, body) = req.into_parts();
270-
271-
use http_body_util::BodyExt as _;
272-
273-
let body_bytes = body
274-
.collect()
275-
.await
276-
.map_err(|_| "Failed to collect body")
277-
.expect("Failed to collect body")
278-
.to_bytes();
279-
280-
let boxed_body = BoxBody::new(http_body_util::Full::new(body_bytes));
281-
282-
let hyper_req = Request::from_parts(parts, boxed_body);
259+
let is_ohttp_request = matches!(
260+
(req.method(), req.uri().path()),
261+
(&Method::POST, "/.well-known/ohttp-gateway") | (&Method::POST, "/")
262+
);
283263

284-
let directory_service = tower::service_fn({
285-
let directory = services.directory.clone();
286-
move |req: Request<BoxBody<Bytes, hyper::Error>>| {
287-
let mut dir = directory.clone();
288-
async move {
289-
dir.call(req).await.map_err(|e| {
290-
Box::new(std::io::Error::other(e.to_string()))
291-
as Box<dyn std::error::Error + Send + Sync>
292-
})
293-
}
264+
if is_ohttp_request {
265+
let app = Router::new()
266+
.fallback(directory_handler)
267+
.layer(axum::middleware::from_fn_with_state(
268+
services.ohttp_config.clone(),
269+
crate::ohttp::ohttp_gateway,
270+
))
271+
.with_state(services.directory.clone());
272+
273+
match app.oneshot(req).await {
274+
Ok(response) => response,
275+
Err(e) =>
276+
(axum::http::StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
294277
}
295-
});
278+
} else {
279+
directory_handler(State(services.directory), req).await
280+
}
281+
}
296282

297-
let mut service_with_ohttp = ServiceBuilder::new()
298-
.layer(OhttpGatewayLayer::new(ohttp_config))
299-
.service(directory_service)
300-
.boxed_clone();
301-
302-
match service_with_ohttp.ready().await {
303-
Ok(ready_service) => match ready_service.call(hyper_req).await {
304-
Ok(response) => {
305-
let (parts, body) = response.into_parts();
306-
let axum_body = axum::body::Body::new(body);
307-
Response::from_parts(parts, axum_body).into_response()
308-
}
309-
Err(e) =>
310-
(StatusCode::INTERNAL_SERVER_ERROR, format!("Service error: {}", e)).into_response(),
311-
},
283+
async fn directory_handler(
284+
State(directory): State<payjoin_directory::Service<payjoin_directory::FilesDb>>,
285+
req: axum::extract::Request,
286+
) -> Response {
287+
let mut dir = directory.clone();
288+
match dir.call(req).await {
289+
Ok(response) => response.into_response(),
312290
Err(e) =>
313-
(StatusCode::INTERNAL_SERVER_ERROR, format!("Service not ready: {}", e)).into_response(),
291+
(axum::http::StatusCode::INTERNAL_SERVER_ERROR, format!("Directory error: {}", e))
292+
.into_response(),
314293
}
315294
}
316295

0 commit comments

Comments
 (0)