@@ -10,40 +10,34 @@ use config::Config;
1010use ohttp_relay:: SentinelTag ;
1111use rand:: Rng ;
1212use tokio_listener:: { Listener , SystemOptions , UserOptions } ;
13- use tower:: { Service , ServiceBuilder } ;
13+ use tower:: { Service , ServiceBuilder , ServiceExt } ;
1414use tracing:: info;
15- pub mod ohttp;
16-
17- use http_body_util:: combinators:: BoxBody ;
18- use hyper:: body:: Bytes ;
19- use hyper:: { Request , StatusCode } ;
20- use ohttp:: { OhttpGatewayConfig , OhttpGatewayLayer } ;
21- use tower:: ServiceExt ;
2215
2316pub mod cli;
2417pub mod config;
2518pub mod metrics;
2619pub mod middleware;
20+ pub mod ohttp;
2721
2822use crate :: metrics:: MetricsService ;
2923use crate :: middleware:: { track_connections, track_metrics} ;
24+ use crate :: ohttp:: OhttpGatewayConfig ;
3025
3126#[ derive( Clone ) ]
3227struct Services {
3328 directory : payjoin_directory:: Service < payjoin_directory:: FilesDb > ,
3429 relay : ohttp_relay:: Service ,
35- sentinel_tag : SentinelTag ,
30+ ohttp_config : OhttpGatewayConfig ,
3631}
3732
3833pub async fn serve ( config : Config ) -> anyhow:: Result < ( ) > {
3934 let sentinel_tag = generate_sentinel_tag ( ) ;
4035 let metrics = MetricsService :: new ( ) ?;
36+ let directory = init_directory ( & config, sentinel_tag) . await ?;
37+ let ohttp_config = OhttpGatewayConfig :: new ( directory. ohttp . clone ( ) , sentinel_tag) ;
4138
42- let services = Services {
43- directory : init_directory ( & config, sentinel_tag) . await ?,
44- relay : ohttp_relay:: Service :: new ( sentinel_tag) . await ,
45- sentinel_tag,
46- } ;
39+ let services =
40+ Services { directory, relay : ohttp_relay:: Service :: new ( sentinel_tag) . await , ohttp_config } ;
4741
4842 let app = build_app ( services, metrics. clone ( ) ) ;
4943 let _ = spawn_metrics_server ( config. metrics . listener . clone ( ) , metrics) . await ?;
@@ -70,16 +64,17 @@ pub async fn serve_manual_tls(
7064 tls_config : Option < axum_server:: tls_rustls:: RustlsConfig > ,
7165 root_store : rustls:: RootCertStore ,
7266) -> anyhow:: Result < ( u16 , u16 , tokio:: task:: JoinHandle < anyhow:: Result < ( ) > > ) > {
73- use std:: net:: SocketAddr ;
74-
7567 let sentinel_tag = generate_sentinel_tag ( ) ;
7668 let metrics = MetricsService :: new ( ) ?;
69+ let directory = init_directory ( & config, sentinel_tag) . await ?;
70+ let ohttp_config = OhttpGatewayConfig :: new ( directory. ohttp . clone ( ) , sentinel_tag) ;
7771
7872 let services = Services {
79- directory : init_directory ( & config , sentinel_tag ) . await ? ,
73+ directory,
8074 relay : ohttp_relay:: Service :: new_with_roots ( root_store, sentinel_tag) . await ,
81- sentinel_tag ,
75+ ohttp_config ,
8276 } ;
77+
8378 let app = build_app ( services, metrics. clone ( ) ) ;
8479 let metrics_port = spawn_metrics_server ( config. metrics . listener . clone ( ) , metrics) . await ?;
8580
@@ -116,7 +111,6 @@ pub async fn serve_manual_tls(
116111/// certificates from Let's Encrypt via the TLS-ALPN-01 challenge.
117112#[ cfg( feature = "acme" ) ]
118113pub async fn serve_acme ( config : Config ) -> anyhow:: Result < ( ) > {
119- use std:: net:: SocketAddr ;
120114 use std:: sync:: Arc ;
121115
122116 let acme_config = config
@@ -126,11 +120,12 @@ pub async fn serve_acme(config: Config) -> anyhow::Result<()> {
126120
127121 let sentinel_tag = generate_sentinel_tag ( ) ;
128122 let metrics = MetricsService :: new ( ) ?;
123+ let directory = init_directory ( & config, sentinel_tag) . await ?;
124+ let ohttp_config = OhttpGatewayConfig :: new ( directory. ohttp . clone ( ) , sentinel_tag) ;
125+
126+ let services =
127+ Services { directory, relay : ohttp_relay:: Service :: new ( sentinel_tag) . await , ohttp_config } ;
129128
130- let services = Services {
131- directory : init_directory ( & config, sentinel_tag) . await ?,
132- relay : ohttp_relay:: Service :: new ( sentinel_tag) . await ,
133- } ;
134129 let app = build_app ( services, metrics. clone ( ) ) ;
135130 let _ = spawn_metrics_server ( config. metrics . listener . clone ( ) , metrics) . await ?;
136131
@@ -246,71 +241,54 @@ async fn spawn_metrics_server(
246241 Ok ( actual_port)
247242}
248243
249- async fn route_request (
250- State ( services) : State < Services > ,
251- req : axum:: extract:: Request ,
252- ) -> Response {
244+ async fn route_request ( State ( services) : State < Services > , req : axum:: extract:: Request ) -> Response {
253245 if is_relay_request ( & req) {
254246 let mut relay = services. relay . clone ( ) ;
255247 match relay. call ( req) . await {
256248 Ok ( res) => res. into_response ( ) ,
257- Err ( e) => ( StatusCode :: BAD_GATEWAY , e. to_string ( ) ) . into_response ( ) ,
249+ Err ( e) => ( axum :: http :: StatusCode :: BAD_GATEWAY , e. to_string ( ) ) . into_response ( ) ,
258250 }
259251 } else {
252+ // The directory service handles all other requests (including 404)
260253 handle_directory_request ( services, req) . await
261254 }
262255}
263256
264257async fn handle_directory_request ( services : Services , req : axum:: extract:: Request ) -> Response {
265- let ohttp_server = services. directory . ohttp . clone ( ) ;
266-
267- let ohttp_config = OhttpGatewayConfig :: new ( ohttp_server, services. sentinel_tag ) ;
268-
269- let ( parts, body) = req. into_parts ( ) ;
270-
271- use http_body_util:: BodyExt as _;
272-
273- let body_bytes = body
274- . collect ( )
275- . await
276- . map_err ( |_| "Failed to collect body" )
277- . expect ( "Failed to collect body" )
278- . to_bytes ( ) ;
279-
280- let boxed_body = BoxBody :: new ( http_body_util:: Full :: new ( body_bytes) ) ;
281-
282- let hyper_req = Request :: from_parts ( parts, boxed_body) ;
258+ let is_ohttp_request = matches ! (
259+ ( req. method( ) , req. uri( ) . path( ) ) ,
260+ ( & Method :: POST , "/.well-known/ohttp-gateway" ) | ( & Method :: POST , "/" )
261+ ) ;
283262
284- let directory_service = tower:: service_fn ( {
285- let directory = services. directory . clone ( ) ;
286- move |req : Request < BoxBody < Bytes , hyper:: Error > > | {
287- let mut dir = directory. clone ( ) ;
288- async move {
289- dir. call ( req) . await . map_err ( |e| {
290- Box :: new ( std:: io:: Error :: other ( e. to_string ( ) ) )
291- as Box < dyn std:: error:: Error + Send + Sync >
292- } )
293- }
263+ if is_ohttp_request {
264+ let app = Router :: new ( )
265+ . fallback ( directory_handler)
266+ . layer ( axum:: middleware:: from_fn_with_state (
267+ services. ohttp_config . clone ( ) ,
268+ crate :: ohttp:: ohttp_gateway,
269+ ) )
270+ . with_state ( services. directory . clone ( ) ) ;
271+
272+ match app. oneshot ( req) . await {
273+ Ok ( response) => response,
274+ Err ( e) =>
275+ ( axum:: http:: StatusCode :: INTERNAL_SERVER_ERROR , e. to_string ( ) ) . into_response ( ) ,
294276 }
295- } ) ;
277+ } else {
278+ directory_handler ( State ( services. directory ) , req) . await
279+ }
280+ }
296281
297- let mut service_with_ohttp = ServiceBuilder :: new ( )
298- . layer ( OhttpGatewayLayer :: new ( ohttp_config) )
299- . service ( directory_service)
300- . boxed_clone ( ) ;
301-
302- match service_with_ohttp. ready ( ) . await {
303- Ok ( ready_service) => match ready_service. call ( hyper_req) . await {
304- Ok ( response) => {
305- let ( parts, body) = response. into_parts ( ) ;
306- let axum_body = axum:: body:: Body :: new ( body) ;
307- Response :: from_parts ( parts, axum_body) . into_response ( )
308- }
309- Err ( e) =>
310- ( StatusCode :: INTERNAL_SERVER_ERROR , format ! ( "Service error: {}" , e) ) . into_response ( ) ,
311- } ,
282+ async fn directory_handler (
283+ State ( directory) : State < payjoin_directory:: Service < payjoin_directory:: FilesDb > > ,
284+ req : axum:: extract:: Request ,
285+ ) -> Response {
286+ let mut dir = directory. clone ( ) ;
287+ match dir. call ( req) . await {
288+ Ok ( response) => response. into_response ( ) ,
312289 Err ( e) =>
313- ( StatusCode :: INTERNAL_SERVER_ERROR , format ! ( "Service not ready: {}" , e) ) . into_response ( ) ,
290+ ( axum:: http:: StatusCode :: INTERNAL_SERVER_ERROR , format ! ( "Directory error: {}" , e) )
291+ . into_response ( ) ,
314292 }
315293}
316294
0 commit comments