@@ -10,40 +10,34 @@ use config::Config;
1010use ohttp_relay:: SentinelTag ;
1111use rand:: Rng ;
1212use tokio_listener:: { Listener , SystemOptions , UserOptions } ;
13- use tower:: { Service , ServiceBuilder } ;
13+ use tower:: { Service , ServiceBuilder , ServiceExt } ;
1414use tracing:: info;
15- pub mod ohttp;
16-
17- use http_body_util:: combinators:: BoxBody ;
18- use hyper:: body:: Bytes ;
19- use hyper:: { Request , StatusCode } ;
20- use ohttp:: { OhttpGatewayConfig , OhttpGatewayLayer } ;
21- use tower:: ServiceExt ;
2215
2316pub mod cli;
2417pub mod config;
2518pub mod metrics;
2619pub mod middleware;
20+ pub mod ohttp;
2721
2822use crate :: metrics:: MetricsService ;
2923use crate :: middleware:: { track_connections, track_metrics} ;
24+ use crate :: ohttp:: OhttpGatewayConfig ;
3025
3126#[ derive( Clone ) ]
3227struct Services {
3328 directory : payjoin_directory:: Service < payjoin_directory:: FilesDb > ,
3429 relay : ohttp_relay:: Service ,
35- sentinel_tag : SentinelTag ,
30+ ohttp_config : OhttpGatewayConfig ,
3631}
3732
3833pub async fn serve ( config : Config ) -> anyhow:: Result < ( ) > {
3934 let sentinel_tag = generate_sentinel_tag ( ) ;
4035 let metrics = MetricsService :: new ( ) ?;
36+ let directory = init_directory ( & config, sentinel_tag) . await ?;
37+ let ohttp_config = OhttpGatewayConfig :: new ( directory. ohttp . clone ( ) , sentinel_tag) ;
4138
42- let services = Services {
43- directory : init_directory ( & config, sentinel_tag) . await ?,
44- relay : ohttp_relay:: Service :: new ( sentinel_tag) . await ,
45- sentinel_tag,
46- } ;
39+ let services =
40+ Services { directory, relay : ohttp_relay:: Service :: new ( sentinel_tag) . await , ohttp_config } ;
4741
4842 let app = build_app ( services, metrics. clone ( ) ) ;
4943 let _ = spawn_metrics_server ( config. metrics . listener . clone ( ) , metrics) . await ?;
@@ -70,16 +64,17 @@ pub async fn serve_manual_tls(
7064 tls_config : Option < axum_server:: tls_rustls:: RustlsConfig > ,
7165 root_store : rustls:: RootCertStore ,
7266) -> anyhow:: Result < ( u16 , u16 , tokio:: task:: JoinHandle < anyhow:: Result < ( ) > > ) > {
73- use std:: net:: SocketAddr ;
74-
7567 let sentinel_tag = generate_sentinel_tag ( ) ;
7668 let metrics = MetricsService :: new ( ) ?;
69+ let directory = init_directory ( & config, sentinel_tag) . await ?;
70+ let ohttp_config = OhttpGatewayConfig :: new ( directory. ohttp . clone ( ) , sentinel_tag) ;
7771
7872 let services = Services {
79- directory : init_directory ( & config , sentinel_tag ) . await ? ,
73+ directory,
8074 relay : ohttp_relay:: Service :: new_with_roots ( root_store, sentinel_tag) . await ,
81- sentinel_tag ,
75+ ohttp_config ,
8276 } ;
77+
8378 let app = build_app ( services, metrics. clone ( ) ) ;
8479 let metrics_port = spawn_metrics_server ( config. metrics . listener . clone ( ) , metrics) . await ?;
8580
@@ -126,11 +121,12 @@ pub async fn serve_acme(config: Config) -> anyhow::Result<()> {
126121
127122 let sentinel_tag = generate_sentinel_tag ( ) ;
128123 let metrics = MetricsService :: new ( ) ?;
124+ let directory = init_directory ( & config, sentinel_tag) . await ?;
125+ let ohttp_config = OhttpGatewayConfig :: new ( directory. ohttp . clone ( ) , sentinel_tag) ;
126+
127+ let services =
128+ Services { directory, relay : ohttp_relay:: Service :: new ( sentinel_tag) . await , ohttp_config } ;
129129
130- let services = Services {
131- directory : init_directory ( & config, sentinel_tag) . await ?,
132- relay : ohttp_relay:: Service :: new ( sentinel_tag) . await ,
133- } ;
134130 let app = build_app ( services, metrics. clone ( ) ) ;
135131 let _ = spawn_metrics_server ( config. metrics . listener . clone ( ) , metrics) . await ?;
136132
@@ -246,71 +242,54 @@ async fn spawn_metrics_server(
246242 Ok ( actual_port)
247243}
248244
249- async fn route_request (
250- State ( services) : State < Services > ,
251- req : axum:: extract:: Request ,
252- ) -> Response {
245+ async fn route_request ( State ( services) : State < Services > , req : axum:: extract:: Request ) -> Response {
253246 if is_relay_request ( & req) {
254247 let mut relay = services. relay . clone ( ) ;
255248 match relay. call ( req) . await {
256249 Ok ( res) => res. into_response ( ) ,
257- Err ( e) => ( StatusCode :: BAD_GATEWAY , e. to_string ( ) ) . into_response ( ) ,
250+ Err ( e) => ( axum :: http :: StatusCode :: BAD_GATEWAY , e. to_string ( ) ) . into_response ( ) ,
258251 }
259252 } else {
253+ // The directory service handles all other requests (including 404)
260254 handle_directory_request ( services, req) . await
261255 }
262256}
263257
264258async fn handle_directory_request ( services : Services , req : axum:: extract:: Request ) -> Response {
265- let ohttp_server = services. directory . ohttp . clone ( ) ;
266-
267- let ohttp_config = OhttpGatewayConfig :: new ( ohttp_server, services. sentinel_tag ) ;
268-
269- let ( parts, body) = req. into_parts ( ) ;
270-
271- use http_body_util:: BodyExt as _;
272-
273- let body_bytes = body
274- . collect ( )
275- . await
276- . map_err ( |_| "Failed to collect body" )
277- . expect ( "Failed to collect body" )
278- . to_bytes ( ) ;
279-
280- let boxed_body = BoxBody :: new ( http_body_util:: Full :: new ( body_bytes) ) ;
281-
282- let hyper_req = Request :: from_parts ( parts, boxed_body) ;
259+ let is_ohttp_request = matches ! (
260+ ( req. method( ) , req. uri( ) . path( ) ) ,
261+ ( & Method :: POST , "/.well-known/ohttp-gateway" ) | ( & Method :: POST , "/" )
262+ ) ;
283263
284- let directory_service = tower:: service_fn ( {
285- let directory = services. directory . clone ( ) ;
286- move |req : Request < BoxBody < Bytes , hyper:: Error > > | {
287- let mut dir = directory. clone ( ) ;
288- async move {
289- dir. call ( req) . await . map_err ( |e| {
290- Box :: new ( std:: io:: Error :: other ( e. to_string ( ) ) )
291- as Box < dyn std:: error:: Error + Send + Sync >
292- } )
293- }
264+ if is_ohttp_request {
265+ let app = Router :: new ( )
266+ . fallback ( directory_handler)
267+ . layer ( axum:: middleware:: from_fn_with_state (
268+ services. ohttp_config . clone ( ) ,
269+ crate :: ohttp:: ohttp_gateway,
270+ ) )
271+ . with_state ( services. directory . clone ( ) ) ;
272+
273+ match app. oneshot ( req) . await {
274+ Ok ( response) => response,
275+ Err ( e) =>
276+ ( axum:: http:: StatusCode :: INTERNAL_SERVER_ERROR , e. to_string ( ) ) . into_response ( ) ,
294277 }
295- } ) ;
278+ } else {
279+ directory_handler ( State ( services. directory ) , req) . await
280+ }
281+ }
296282
297- let mut service_with_ohttp = ServiceBuilder :: new ( )
298- . layer ( OhttpGatewayLayer :: new ( ohttp_config) )
299- . service ( directory_service)
300- . boxed_clone ( ) ;
301-
302- match service_with_ohttp. ready ( ) . await {
303- Ok ( ready_service) => match ready_service. call ( hyper_req) . await {
304- Ok ( response) => {
305- let ( parts, body) = response. into_parts ( ) ;
306- let axum_body = axum:: body:: Body :: new ( body) ;
307- Response :: from_parts ( parts, axum_body) . into_response ( )
308- }
309- Err ( e) =>
310- ( StatusCode :: INTERNAL_SERVER_ERROR , format ! ( "Service error: {}" , e) ) . into_response ( ) ,
311- } ,
283+ async fn directory_handler (
284+ State ( directory) : State < payjoin_directory:: Service < payjoin_directory:: FilesDb > > ,
285+ req : axum:: extract:: Request ,
286+ ) -> Response {
287+ let mut dir = directory. clone ( ) ;
288+ match dir. call ( req) . await {
289+ Ok ( response) => response. into_response ( ) ,
312290 Err ( e) =>
313- ( StatusCode :: INTERNAL_SERVER_ERROR , format ! ( "Service not ready: {}" , e) ) . into_response ( ) ,
291+ ( axum:: http:: StatusCode :: INTERNAL_SERVER_ERROR , format ! ( "Directory error: {}" , e) )
292+ . into_response ( ) ,
314293 }
315294}
316295
0 commit comments