1- use blitz_traits:: net:: { BoxedHandler , Bytes , NetCallback , NetProvider , Request , SharedCallback } ;
1+ use blitz_traits:: net:: {
2+ AbortSignal , BoxedHandler , Bytes , NetCallback , NetProvider , Request , SharedCallback ,
3+ } ;
24use data_url:: DataUrl ;
35use reqwest:: Client ;
4- use std:: sync:: Arc ;
6+ use std:: { marker :: PhantomData , pin :: Pin , sync:: Arc , task :: Poll } ;
57use tokio:: {
68 runtime:: Handle ,
79 sync:: mpsc:: { UnboundedReceiver , UnboundedSender , unbounded_channel} ,
@@ -75,18 +77,6 @@ impl<D: 'static> Provider<D> {
7577 } )
7678 }
7779
78- async fn fetch_with_handler (
79- client : Client ,
80- doc_id : usize ,
81- request : Request ,
82- handler : BoxedHandler < D > ,
83- res_callback : SharedCallback < D > ,
84- ) -> Result < ( ) , ProviderError > {
85- let ( _response_url, bytes) = Self :: fetch_inner ( client, request) . await ?;
86- handler. bytes ( doc_id, bytes, res_callback) ;
87- Ok ( ( ) )
88- }
89-
9080 #[ allow( clippy:: type_complexity) ]
9181 pub fn fetch_with_callback (
9282 & self ,
@@ -108,24 +98,78 @@ impl<D: 'static> Provider<D> {
10898}
10999
110100impl < D : ' static > NetProvider < D > for Provider < D > {
111- fn fetch ( & self , doc_id : usize , request : Request , handler : BoxedHandler < D > ) {
101+ fn fetch ( & self , doc_id : usize , mut request : Request , handler : BoxedHandler < D > ) {
112102 let client = self . client . clone ( ) ;
113103 let callback = Arc :: clone ( & self . resource_callback ) ;
114104 println ! ( "Fetching {}" , & request. url) ;
115105 self . rt . spawn ( async move {
116106 let url = request. url . to_string ( ) ;
117- let res = Self :: fetch_with_handler ( client, doc_id, request, handler, callback) . await ;
118- if let Err ( e) = res {
119- eprintln ! ( "Error fetching {url}: {e:?}" ) ;
107+ let signal = request. signal . take ( ) ;
108+ let result = if let Some ( signal) = signal {
109+ AbortFetch :: new (
110+ signal,
111+ Box :: pin ( async move { Self :: fetch_inner ( client, request) . await } ) ,
112+ )
113+ . await
120114 } else {
121- println ! ( "Success {url}" ) ;
115+ Self :: fetch_inner ( client, request) . await
116+ } ;
117+
118+ match result {
119+ Ok ( ( _response_url, bytes) ) => {
120+ handler. bytes ( doc_id, bytes, callback) ;
121+ println ! ( "Success {url}" ) ;
122+ }
123+ Err ( e) => {
124+ eprintln ! ( "Error fetching {url}: {e:?}" ) ;
125+ }
122126 }
123127 } ) ;
124128 }
125129}
126130
131+ struct AbortFetch < F , T > {
132+ signal : AbortSignal ,
133+ future : F ,
134+ _rt : PhantomData < T > ,
135+ }
136+
137+ impl < F , T > AbortFetch < F , T > {
138+ fn new ( signal : AbortSignal , future : F ) -> Self {
139+ Self {
140+ signal,
141+ future,
142+ _rt : PhantomData ,
143+ }
144+ }
145+ }
146+
147+ impl < F , T > Future for AbortFetch < F , T >
148+ where
149+ F : Future + Unpin + Send + ' static ,
150+ F :: Output : Send + Into < Result < T , ProviderError > > + ' static ,
151+ T : Unpin ,
152+ {
153+ type Output = Result < T , ProviderError > ;
154+
155+ fn poll (
156+ mut self : std:: pin:: Pin < & mut Self > ,
157+ cx : & mut std:: task:: Context < ' _ > ,
158+ ) -> std:: task:: Poll < Self :: Output > {
159+ if self . signal . aborted ( ) {
160+ return Poll :: Ready ( Err ( ProviderError :: Abort ) ) ;
161+ }
162+
163+ match Pin :: new ( & mut self . future ) . poll ( cx) {
164+ Poll :: Ready ( output) => Poll :: Ready ( output. into ( ) ) ,
165+ Poll :: Pending => Poll :: Pending ,
166+ }
167+ }
168+ }
169+
127170#[ derive( Debug ) ]
128171pub enum ProviderError {
172+ Abort ,
129173 Io ( std:: io:: Error ) ,
130174 DataUrl ( data_url:: DataUrlError ) ,
131175 DataUrlBase64 ( data_url:: forgiving_base64:: InvalidBase64 ) ,
0 commit comments