@@ -32,24 +32,16 @@ pub async fn main_entry(config: Config, shutdown_token: tokio_util::sync::Cancel
3232 let timeout = Duration :: from_secs ( config. timeout ) ;
3333
3434 let cache = create_dns_cache ( ) ;
35-
36- fn handle_error ( res : Result < Result < ( ) , Error > , tokio:: task:: JoinError > , protocol : & str ) {
37- match res {
38- Ok ( Err ( e) ) => log:: error!( "{} error \" {}\" " , protocol, e) ,
39- Err ( e) => log:: error!( "{} error \" {}\" " , protocol, e) ,
40- _ => { }
41- }
42- }
43-
35+ let shutdown_for_select = shutdown_token. clone ( ) ;
4436 tokio:: select! {
45- _ = shutdown_token . cancelled( ) => {
37+ _ = shutdown_for_select . cancelled( ) => {
4638 log:: info!( "Shutdown received" ) ;
4739 } ,
48- res = tokio :: spawn ( udp_thread( config. clone( ) , user_key. clone( ) , cache. clone( ) , timeout) ) => {
49- handle_error ( res, "UDP" ) ;
40+ res = udp_thread( config. clone( ) , user_key. clone( ) , cache. clone( ) , timeout, shutdown_token . clone ( ) ) => {
41+ res? ;
5042 } ,
51- res = tokio :: spawn ( tcp_thread( config, user_key, cache, timeout) ) => {
52- handle_error ( res, "TCP" ) ;
43+ res = tcp_thread( config, user_key, cache, timeout, shutdown_token ) => {
44+ res? ;
5345 } ,
5446 }
5547
@@ -58,7 +50,13 @@ pub async fn main_entry(config: Config, shutdown_token: tokio_util::sync::Cancel
5850 Ok ( ( ) )
5951}
6052
61- pub ( crate ) async fn udp_thread ( opt : Config , user_key : Option < UserKey > , cache : Cache < Vec < Query > , Message > , timeout : Duration ) -> Result < ( ) > {
53+ pub ( crate ) async fn udp_thread (
54+ opt : Config ,
55+ user_key : Option < UserKey > ,
56+ cache : Cache < Vec < Query > , Message > ,
57+ timeout : Duration ,
58+ shutdown_token : tokio_util:: sync:: CancellationToken ,
59+ ) -> Result < ( ) > {
6260 let listener = match UdpSocket :: bind ( & opt. listen_addr ) . await {
6361 Ok ( listener) => listener,
6462 Err ( e) => {
@@ -74,19 +72,27 @@ pub(crate) async fn udp_thread(opt: Config, user_key: Option<UserKey>, cache: Ca
7472 let opt = opt. clone ( ) ;
7573 let cache = cache. clone ( ) ;
7674 let auth = user_key. clone ( ) ;
77- let block = async move {
78- let mut buf = vec ! [ 0u8 ; MAX_BUFFER_SIZE ] ;
79- let ( len, src) = listener. recv_from ( & mut buf) . await ?;
80- buf. resize ( len, 0 ) ;
81- tokio:: spawn ( async move {
82- if let Err ( e) = udp_incoming_handler ( listener, buf, src, opt, cache, auth, timeout) . await {
83- log:: error!( "DNS query via UDP incoming handler error \" {}\" " , e) ;
75+ tokio:: select! {
76+ _ = shutdown_token. cancelled( ) => {
77+ log:: info!( "UDP shutdown received" ) ;
78+ return Ok ( ( ) ) ;
79+ }
80+ res = async move {
81+ let mut buf = vec![ 0u8 ; MAX_BUFFER_SIZE ] ;
82+ let ( len, src) = listener. recv_from( & mut buf) . await ?;
83+ buf. resize( len, 0 ) ;
84+ tokio:: spawn( async move {
85+ if let Err ( e) = udp_incoming_handler( listener, buf, src, opt, cache, auth, timeout) . await {
86+ log:: error!( "DNS query via UDP incoming handler error \" {}\" " , e) ;
87+ }
88+ } ) ;
89+ Ok :: <( ) , Error >( ( ) )
90+ } => {
91+ if let Err ( e) = res {
92+ log:: error!( "UDP listener error \" {}\" " , e) ;
93+ return Err ( e) ;
8494 }
85- } ) ;
86- Ok :: < ( ) , Error > ( ( ) )
87- } ;
88- if let Err ( e) = block. await {
89- log:: error!( "UDP listener error \" {}\" " , e) ;
95+ }
9096 }
9197 }
9298}
@@ -142,7 +148,13 @@ async fn udp_incoming_handler(
142148 Ok :: < ( ) , Error > ( ( ) )
143149}
144150
145- pub ( crate ) async fn tcp_thread ( opt : Config , user_key : Option < UserKey > , cache : Cache < Vec < Query > , Message > , timeout : Duration ) -> Result < ( ) > {
151+ pub ( crate ) async fn tcp_thread (
152+ opt : Config ,
153+ user_key : Option < UserKey > ,
154+ cache : Cache < Vec < Query > , Message > ,
155+ timeout : Duration ,
156+ shutdown_token : tokio_util:: sync:: CancellationToken ,
157+ ) -> Result < ( ) > {
146158 let listener = match TcpListener :: bind ( & opt. listen_addr ) . await {
147159 Ok ( listener) => listener,
148160 Err ( e) => {
@@ -152,17 +164,31 @@ pub(crate) async fn tcp_thread(opt: Config, user_key: Option<UserKey>, cache: Ca
152164 } ;
153165 log:: info!( "TCP listening on: {}" , opt. listen_addr) ;
154166
155- while let Ok ( ( mut incoming, _) ) = listener. accept ( ) . await {
156- let opt = opt. clone ( ) ;
157- let user_key = user_key. clone ( ) ;
158- let cache = cache. clone ( ) ;
159- tokio:: spawn ( async move {
160- if let Err ( e) = handle_tcp_incoming ( & opt, user_key, cache, & mut incoming, timeout) . await {
161- log:: error!( "TCP error \" {}\" " , e) ;
167+ loop {
168+ tokio:: select! {
169+ _ = shutdown_token. cancelled( ) => {
170+ log:: info!( "TCP shutdown received" ) ;
171+ return Ok ( ( ) ) ;
172+ }
173+ res = listener. accept( ) => {
174+ let ( mut incoming, _) = match res {
175+ Ok ( conn) => conn,
176+ Err ( e) => {
177+ log:: error!( "TCP listener {} error \" {}\" " , opt. listen_addr, e) ;
178+ return Err ( e. into( ) ) ;
179+ }
180+ } ;
181+ let opt = opt. clone( ) ;
182+ let user_key = user_key. clone( ) ;
183+ let cache = cache. clone( ) ;
184+ tokio:: spawn( async move {
185+ if let Err ( e) = handle_tcp_incoming( & opt, user_key, cache, & mut incoming, timeout) . await {
186+ log:: error!( "TCP error \" {}\" " , e) ;
187+ }
188+ } ) ;
162189 }
163- } ) ;
190+ } ;
164191 }
165- Ok ( ( ) )
166192}
167193
168194async fn handle_tcp_incoming (
@@ -172,10 +198,16 @@ async fn handle_tcp_incoming(
172198 incoming : & mut TcpStream ,
173199 timeout : Duration ,
174200) -> Result < ( ) > {
175- let mut buf = [ 0u8 ; MAX_BUFFER_SIZE ] ;
176- let n = tokio:: time:: timeout ( timeout, incoming. read ( & mut buf) ) . await ??;
201+ let mut len_buf = [ 0u8 ; 2 ] ;
202+ tokio:: time:: timeout ( timeout, incoming. read_exact ( & mut len_buf) ) . await ??;
203+ let len = u16:: from_be_bytes ( len_buf) as usize ;
204+ let mut msg_buf = vec ! [ 0u8 ; len] ;
205+ tokio:: time:: timeout ( timeout, incoming. read_exact ( & mut msg_buf) ) . await ??;
206+
207+ let mut buf = len_buf. to_vec ( ) ;
208+ buf. extend ( msg_buf) ;
177209
178- let message = dns:: parse_data_to_dns_message ( & buf[ ..n ] , true ) ?;
210+ let message = dns:: parse_data_to_dns_message ( & buf, true ) ?;
179211 let domain = dns:: extract_domain_from_dns_message ( & message) ?;
180212
181213 if opt. cache_records
@@ -191,7 +223,7 @@ async fn handle_tcp_incoming(
191223
192224 let proxy_addr = opt. socks5_settings . addr ;
193225 let target_server = opt. dns_remote_server ;
194- let response_buf = tcp_via_socks5_server ( proxy_addr, target_server, auth, & buf[ ..n ] , timeout) . await ?;
226+ let response_buf = tcp_via_socks5_server ( proxy_addr, target_server, auth, & buf, timeout) . await ?;
195227
196228 incoming. write_all ( & response_buf) . await ?;
197229
@@ -216,9 +248,9 @@ where
216248 A : ToSocketAddrs ,
217249 B : Into < Address > ,
218250{
219- let s5_proxy = TcpStream :: connect ( proxy_addr) . await ?;
251+ let s5_proxy = tokio :: time :: timeout ( timeout , TcpStream :: connect ( proxy_addr) ) . await ? ?;
220252 let mut stream = BufStream :: new ( s5_proxy) ;
221- let _addr = client:: connect ( & mut stream, target_server, auth) . await ?;
253+ let _addr = tokio :: time :: timeout ( timeout , client:: connect ( & mut stream, target_server, auth) ) . await ? ?;
222254
223255 stream. write_all ( buf) . await ?;
224256 stream. flush ( ) . await ?;
0 commit comments