Skip to content

Commit 7c3eef5

Browse files
committed
Fix DNS2SOCKS shutdown and loop handling
1 parent 6b204a3 commit 7c3eef5

3 files changed

Lines changed: 95 additions & 48 deletions

File tree

src/api.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,17 @@ pub unsafe extern "C" fn dns2socks_start(
2424
verbosity: ArgVerbosity,
2525
timeout: i32,
2626
) -> c_int {
27+
struct TunQuitGuard;
28+
29+
impl Drop for TunQuitGuard {
30+
fn drop(&mut self) {
31+
if let Ok(mut lock) = TUN_QUIT.lock() {
32+
*lock = None;
33+
}
34+
}
35+
}
36+
37+
let _guard = TunQuitGuard;
2738
let shutdown_token = tokio_util::sync::CancellationToken::new();
2839
{
2940
if let Ok(mut lock) = TUN_QUIT.lock() {
@@ -41,6 +52,8 @@ pub unsafe extern "C" fn dns2socks_start(
4152
log::warn!("set logger error: {}", err);
4253
}
4354

55+
let timeout = if timeout > 0 { timeout } else { 5 };
56+
4457
let mut config = crate::Config::default();
4558
config
4659
.verbosity(verbosity)

src/dump_logger.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use crate::ArgVerbosity;
22
use std::os::raw::{c_char, c_void};
33

4-
static DUMP_CALLBACK: std::sync::OnceLock<Option<DumpCallback>> = std::sync::OnceLock::new();
4+
static DUMP_CALLBACK: std::sync::Mutex<Option<DumpCallback>> = std::sync::Mutex::new(None);
55

66
/// # Safety
77
///
@@ -11,10 +11,11 @@ pub unsafe extern "C" fn dns2socks_set_log_callback(
1111
callback: Option<unsafe extern "C" fn(ArgVerbosity, *const c_char, *mut c_void)>,
1212
ctx: *mut c_void,
1313
) {
14-
if let Some(_cb) = DUMP_CALLBACK.get_or_init(|| Some(DumpCallback(callback, ctx))) {
14+
if let Ok(mut lock) = DUMP_CALLBACK.lock() {
15+
*lock = Some(DumpCallback(callback, ctx));
1516
log::info!("dump log callback set success");
1617
} else {
17-
log::warn!("dump log callback already set");
18+
log::warn!("dump log callback set failed");
1819
}
1920
}
2021

@@ -66,8 +67,9 @@ impl DumpLogger {
6667
return;
6768
};
6869
let ptr = c_msg.as_ptr();
69-
if let Some(Some(cb)) = DUMP_CALLBACK.get() {
70-
unsafe { cb.clone().call(record.level().into(), ptr) };
70+
let callback = DUMP_CALLBACK.lock().ok().and_then(|lock| lock.clone());
71+
if let Some(cb) = callback {
72+
unsafe { cb.call(record.level().into(), ptr) };
7173
}
7274
}
7375
}

src/lib.rs

Lines changed: 75 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -32,24 +32,16 @@ pub async fn main_entry(config: Config, shutdown_token: tokio_util::sync::Cancel
3232
let timeout = Duration::from_secs(config.timeout);
3333

3434
let cache = create_dns_cache();
35-
36-
fn handle_error(res: Result<Result<(), Error>, tokio::task::JoinError>, protocol: &str) {
37-
match res {
38-
Ok(Err(e)) => log::error!("{} error \"{}\"", protocol, e),
39-
Err(e) => log::error!("{} error \"{}\"", protocol, e),
40-
_ => {}
41-
}
42-
}
43-
35+
let shutdown_for_select = shutdown_token.clone();
4436
tokio::select! {
45-
_ = shutdown_token.cancelled() => {
37+
_ = shutdown_for_select.cancelled() => {
4638
log::info!("Shutdown received");
4739
},
48-
res = tokio::spawn(udp_thread(config.clone(), user_key.clone(), cache.clone(), timeout)) => {
49-
handle_error(res, "UDP");
40+
res = udp_thread(config.clone(), user_key.clone(), cache.clone(), timeout, shutdown_token.clone()) => {
41+
res?;
5042
},
51-
res = tokio::spawn(tcp_thread(config, user_key, cache, timeout)) => {
52-
handle_error(res, "TCP");
43+
res = tcp_thread(config, user_key, cache, timeout, shutdown_token) => {
44+
res?;
5345
},
5446
}
5547

@@ -58,7 +50,13 @@ pub async fn main_entry(config: Config, shutdown_token: tokio_util::sync::Cancel
5850
Ok(())
5951
}
6052

61-
pub(crate) async fn udp_thread(opt: Config, user_key: Option<UserKey>, cache: Cache<Vec<Query>, Message>, timeout: Duration) -> Result<()> {
53+
pub(crate) async fn udp_thread(
54+
opt: Config,
55+
user_key: Option<UserKey>,
56+
cache: Cache<Vec<Query>, Message>,
57+
timeout: Duration,
58+
shutdown_token: tokio_util::sync::CancellationToken,
59+
) -> Result<()> {
6260
let listener = match UdpSocket::bind(&opt.listen_addr).await {
6361
Ok(listener) => listener,
6462
Err(e) => {
@@ -74,19 +72,27 @@ pub(crate) async fn udp_thread(opt: Config, user_key: Option<UserKey>, cache: Ca
7472
let opt = opt.clone();
7573
let cache = cache.clone();
7674
let auth = user_key.clone();
77-
let block = async move {
78-
let mut buf = vec![0u8; MAX_BUFFER_SIZE];
79-
let (len, src) = listener.recv_from(&mut buf).await?;
80-
buf.resize(len, 0);
81-
tokio::spawn(async move {
82-
if let Err(e) = udp_incoming_handler(listener, buf, src, opt, cache, auth, timeout).await {
83-
log::error!("DNS query via UDP incoming handler error \"{}\"", e);
75+
tokio::select! {
76+
_ = shutdown_token.cancelled() => {
77+
log::info!("UDP shutdown received");
78+
return Ok(());
79+
}
80+
res = async move {
81+
let mut buf = vec![0u8; MAX_BUFFER_SIZE];
82+
let (len, src) = listener.recv_from(&mut buf).await?;
83+
buf.resize(len, 0);
84+
tokio::spawn(async move {
85+
if let Err(e) = udp_incoming_handler(listener, buf, src, opt, cache, auth, timeout).await {
86+
log::error!("DNS query via UDP incoming handler error \"{}\"", e);
87+
}
88+
});
89+
Ok::<(), Error>(())
90+
} => {
91+
if let Err(e) = res {
92+
log::error!("UDP listener error \"{}\"", e);
93+
return Err(e);
8494
}
85-
});
86-
Ok::<(), Error>(())
87-
};
88-
if let Err(e) = block.await {
89-
log::error!("UDP listener error \"{}\"", e);
95+
}
9096
}
9197
}
9298
}
@@ -142,7 +148,13 @@ async fn udp_incoming_handler(
142148
Ok::<(), Error>(())
143149
}
144150

145-
pub(crate) async fn tcp_thread(opt: Config, user_key: Option<UserKey>, cache: Cache<Vec<Query>, Message>, timeout: Duration) -> Result<()> {
151+
pub(crate) async fn tcp_thread(
152+
opt: Config,
153+
user_key: Option<UserKey>,
154+
cache: Cache<Vec<Query>, Message>,
155+
timeout: Duration,
156+
shutdown_token: tokio_util::sync::CancellationToken,
157+
) -> Result<()> {
146158
let listener = match TcpListener::bind(&opt.listen_addr).await {
147159
Ok(listener) => listener,
148160
Err(e) => {
@@ -152,17 +164,31 @@ pub(crate) async fn tcp_thread(opt: Config, user_key: Option<UserKey>, cache: Ca
152164
};
153165
log::info!("TCP listening on: {}", opt.listen_addr);
154166

155-
while let Ok((mut incoming, _)) = listener.accept().await {
156-
let opt = opt.clone();
157-
let user_key = user_key.clone();
158-
let cache = cache.clone();
159-
tokio::spawn(async move {
160-
if let Err(e) = handle_tcp_incoming(&opt, user_key, cache, &mut incoming, timeout).await {
161-
log::error!("TCP error \"{}\"", e);
167+
loop {
168+
tokio::select! {
169+
_ = shutdown_token.cancelled() => {
170+
log::info!("TCP shutdown received");
171+
return Ok(());
172+
}
173+
res = listener.accept() => {
174+
let (mut incoming, _) = match res {
175+
Ok(conn) => conn,
176+
Err(e) => {
177+
log::error!("TCP listener {} error \"{}\"", opt.listen_addr, e);
178+
return Err(e.into());
179+
}
180+
};
181+
let opt = opt.clone();
182+
let user_key = user_key.clone();
183+
let cache = cache.clone();
184+
tokio::spawn(async move {
185+
if let Err(e) = handle_tcp_incoming(&opt, user_key, cache, &mut incoming, timeout).await {
186+
log::error!("TCP error \"{}\"", e);
187+
}
188+
});
162189
}
163-
});
190+
};
164191
}
165-
Ok(())
166192
}
167193

168194
async fn handle_tcp_incoming(
@@ -172,10 +198,16 @@ async fn handle_tcp_incoming(
172198
incoming: &mut TcpStream,
173199
timeout: Duration,
174200
) -> Result<()> {
175-
let mut buf = [0u8; MAX_BUFFER_SIZE];
176-
let n = tokio::time::timeout(timeout, incoming.read(&mut buf)).await??;
201+
let mut len_buf = [0u8; 2];
202+
tokio::time::timeout(timeout, incoming.read_exact(&mut len_buf)).await??;
203+
let len = u16::from_be_bytes(len_buf) as usize;
204+
let mut msg_buf = vec![0u8; len];
205+
tokio::time::timeout(timeout, incoming.read_exact(&mut msg_buf)).await??;
206+
207+
let mut buf = len_buf.to_vec();
208+
buf.extend(msg_buf);
177209

178-
let message = dns::parse_data_to_dns_message(&buf[..n], true)?;
210+
let message = dns::parse_data_to_dns_message(&buf, true)?;
179211
let domain = dns::extract_domain_from_dns_message(&message)?;
180212

181213
if opt.cache_records
@@ -191,7 +223,7 @@ async fn handle_tcp_incoming(
191223

192224
let proxy_addr = opt.socks5_settings.addr;
193225
let target_server = opt.dns_remote_server;
194-
let response_buf = tcp_via_socks5_server(proxy_addr, target_server, auth, &buf[..n], timeout).await?;
226+
let response_buf = tcp_via_socks5_server(proxy_addr, target_server, auth, &buf, timeout).await?;
195227

196228
incoming.write_all(&response_buf).await?;
197229

@@ -216,9 +248,9 @@ where
216248
A: ToSocketAddrs,
217249
B: Into<Address>,
218250
{
219-
let s5_proxy = TcpStream::connect(proxy_addr).await?;
251+
let s5_proxy = tokio::time::timeout(timeout, TcpStream::connect(proxy_addr)).await??;
220252
let mut stream = BufStream::new(s5_proxy);
221-
let _addr = client::connect(&mut stream, target_server, auth).await?;
253+
let _addr = tokio::time::timeout(timeout, client::connect(&mut stream, target_server, auth)).await??;
222254

223255
stream.write_all(buf).await?;
224256
stream.flush().await?;

0 commit comments

Comments
 (0)