From 4e30ee8d1cdf224dc45366523b71002fa3ef0cd2 Mon Sep 17 00:00:00 2001 From: RustDesk <71636191+rustdesk@users.noreply.github.com> Date: Fri, 3 Apr 2026 23:13:05 +0800 Subject: [PATCH] tcp proxy (#14633) * tcp proxy * fix per review * fix per review * Suppress secure_tcp info logs for TCP proxy requests Signed-off-by: 21pages * copilot review: redact tcp proxy logs, dedupe headers, and avoid body clone Signed-off-by: 21pages * format common.rs Signed-off-by: 21pages * copilot review: test function name Signed-off-by: 21pages * copilot review: format IPv6 tcp proxy log targets correctly Signed-off-by: 21pages * copilot review: normalize HTTP method before direct request dispatch Signed-off-by: 21pages * review: extract fallback helper, fix Content-Type override, add overall timeout - Extract duplicated TCP proxy fallback logic into generic `with_tcp_proxy_fallback` helper used by both `post_request` and `http_request_sync`, eliminating code drift risk - Allow caller-supplied Content-Type to override the default in `parse_simple_header` instead of silently dropping it - Take body by reference in `post_request_http` to avoid eager clone when no fallback is needed - Wrap entire `tcp_proxy_request` flow (connect + handshake + send + receive) in an overall timeout to prevent indefinite stalls Co-Authored-By: Claude Opus 4.6 * review: make is_public case-insensitive and cover mixed-case rustdesk URLs Signed-off-by: 21pages * oidc: route auth requests through shared HTTP/tcp-proxy path while keeping TLS warmup Signed-off-by: 21pages * refactor: replace unused TryFrom with HbbHttpResponse::parse method Remove TryFrom impl that was never called and replace the private parse_hbb_http_response helper in account.rs with a public parse() method on HbbHttpResponse, eliminating code duplication. Signed-off-by: 21pages --------- Signed-off-by: 21pages Co-authored-by: 21pages Co-authored-by: Claude Opus 4.6 --- src/common.rs | 570 ++++++++++++++++++++++++++++++++++++--- src/hbbs_http.rs | 10 +- src/hbbs_http/account.rs | 61 ++--- 3 files changed, 562 insertions(+), 79 deletions(-) diff --git a/src/common.rs b/src/common.rs index 3e23770c6d5..69e3ec3045d 100644 --- a/src/common.rs +++ b/src/common.rs @@ -39,7 +39,7 @@ use hbb_common::{ use crate::{ hbbs_http::{create_http_client_async, get_url_for_tls}, - ui_interface::{get_option, is_installed, set_option}, + ui_interface::{get_api_server as ui_get_api_server, get_option, is_installed, set_option}, }; #[derive(Debug, Eq, PartialEq)] @@ -1086,6 +1086,7 @@ fn get_api_server_(api: String, custom: String) -> String { #[inline] pub fn is_public(url: &str) -> bool { + let url = url.to_ascii_lowercase(); url.contains("rustdesk.com/") || url.ends_with("rustdesk.com") } @@ -1123,22 +1124,286 @@ pub fn get_audit_server(api: String, custom: String, typ: String) -> String { format!("{}/api/audit/{}", url, typ) } -pub async fn post_request(url: String, body: String, header: &str) -> ResultType { +/// Check if we should use raw TCP proxy for API calls. +/// Returns true if USE_RAW_TCP_FOR_API builtin option is "Y", WebSocket is off, +/// and the target URL belongs to the configured non-public API host. +#[inline] +fn should_use_raw_tcp_for_api(url: &str) -> bool { + get_builtin_option(keys::OPTION_USE_RAW_TCP_FOR_API) == "Y" + && !use_ws() + && is_tcp_proxy_api_target(url) +} + +/// Check if we can attempt raw TCP proxy fallback for this target URL. +#[inline] +fn can_fallback_to_raw_tcp(url: &str) -> bool { + !use_ws() && is_tcp_proxy_api_target(url) +} + +#[inline] +fn should_use_tcp_proxy_for_api_url(url: &str, api_url: &str) -> bool { + if api_url.is_empty() || is_public(api_url) { + return false; + } + + let target_host = url::Url::parse(url) + .ok() + .and_then(|parsed| parsed.host_str().map(|host| host.to_ascii_lowercase())); + let api_host = url::Url::parse(api_url) + .ok() + .and_then(|parsed| parsed.host_str().map(|host| host.to_ascii_lowercase())); + + matches!((target_host, api_host), (Some(target), Some(api)) if target == api) +} + +#[inline] +fn is_tcp_proxy_api_target(url: &str) -> bool { + should_use_tcp_proxy_for_api_url(url, &ui_get_api_server()) +} + +fn tcp_proxy_log_target(url: &str) -> String { + url::Url::parse(url) + .ok() + .map(|parsed| { + let mut redacted = format!("{}://", parsed.scheme()); + let Some(host) = parsed.host() else { + return "".to_owned(); + }; + redacted.push_str(&host.to_string()); + if let Some(port) = parsed.port() { + redacted.push(':'); + redacted.push_str(&port.to_string()); + } + redacted.push_str(parsed.path()); + redacted + }) + .unwrap_or_else(|| "".to_owned()) +} + +#[inline] +fn get_tcp_proxy_addr() -> String { + check_port(Config::get_rendezvous_server(), RENDEZVOUS_PORT) +} + +/// Send an HTTP request via the rendezvous server's TCP proxy using protobuf. +/// Connects with `connect_tcp` + `secure_tcp`, sends `HttpProxyRequest`, +/// receives `HttpProxyResponse`. +/// +/// The entire operation (connect + handshake + send + receive) is wrapped in +/// an overall timeout of `CONNECT_TIMEOUT + READ_TIMEOUT` so that a stall at +/// any stage cannot block the caller indefinitely. +async fn tcp_proxy_request( + method: &str, + url: &str, + body: &[u8], + headers: Vec, +) -> ResultType { + let tcp_addr = get_tcp_proxy_addr(); + if tcp_addr.is_empty() { + bail!("No rendezvous server configured for TCP proxy"); + } + + let parsed = url::Url::parse(url)?; + let path = if let Some(query) = parsed.query() { + format!("{}?{}", parsed.path(), query) + } else { + parsed.path().to_string() + }; + + log::debug!( + "Sending {} {} via TCP proxy to {}", + method, + parsed.path(), + tcp_addr + ); + + let overall_timeout = CONNECT_TIMEOUT + READ_TIMEOUT; + timeout(overall_timeout, async { + let mut conn = socket_client::connect_tcp(&*tcp_addr, CONNECT_TIMEOUT).await?; + let key = crate::get_key(true).await; + secure_tcp_silent(&mut conn, &key).await?; + + let mut req = HttpProxyRequest::new(); + req.method = method.to_uppercase(); + req.path = path; + req.headers = headers.into(); + req.body = Bytes::from(body.to_vec()); + + let mut msg_out = RendezvousMessage::new(); + msg_out.set_http_proxy_request(req); + conn.send(&msg_out).await?; + + match conn.next().await { + Some(Ok(bytes)) => { + let msg_in = RendezvousMessage::parse_from_bytes(&bytes)?; + match msg_in.union { + Some(rendezvous_message::Union::HttpProxyResponse(resp)) => Ok(resp), + _ => bail!("Unexpected response from TCP proxy"), + } + } + Some(Err(e)) => bail!("TCP proxy read error: {}", e), + None => bail!("TCP proxy connection closed without response"), + } + }) + .await? +} + +/// Build HeaderEntry list from "Key: Value" style header string (used by post_request). +/// If the caller supplies a Content-Type header it overrides the default `application/json`. +fn parse_simple_header(header: &str) -> Vec { + let mut entries = Vec::new(); + let mut has_content_type = false; + if !header.is_empty() { + let tmp: Vec<&str> = header.splitn(2, ": ").collect(); + if tmp.len() == 2 { + if tmp[0].eq_ignore_ascii_case("Content-Type") { + has_content_type = true; + } + entries.push(HeaderEntry { + name: tmp[0].into(), + value: tmp[1].into(), + ..Default::default() + }); + } + } + if !has_content_type { + entries.insert( + 0, + HeaderEntry { + name: "Content-Type".into(), + value: "application/json".into(), + ..Default::default() + }, + ); + } + entries +} + +/// POST request via TCP proxy. +async fn post_request_via_tcp_proxy(url: &str, body: &str, header: &str) -> ResultType { + let headers = parse_simple_header(header); + let resp = tcp_proxy_request("POST", url, body.as_bytes(), headers).await?; + if !resp.error.is_empty() { + bail!("TCP proxy error: {}", resp.error); + } + Ok(String::from_utf8_lossy(&resp.body).to_string()) +} + +fn http_proxy_response_to_json(resp: HttpProxyResponse) -> ResultType { + if !resp.error.is_empty() { + bail!("TCP proxy error: {}", resp.error); + } + + let mut response_headers = Map::new(); + for entry in resp.headers.iter() { + response_headers.insert(entry.name.to_lowercase(), json!(entry.value)); + } + + let mut result = Map::new(); + result.insert("status_code".to_string(), json!(resp.status)); + result.insert("headers".to_string(), Value::Object(response_headers)); + result.insert( + "body".to_string(), + json!(String::from_utf8_lossy(&resp.body)), + ); + + serde_json::to_string(&result).map_err(|e| anyhow!("Failed to serialize response: {}", e)) +} + +fn parse_json_header_entries(header: &str) -> ResultType> { + let v: Value = serde_json::from_str(header)?; + if let Value::Object(obj) = v { + Ok(obj + .iter() + .map(|(key, value)| HeaderEntry { + name: key.clone(), + value: value.as_str().unwrap_or_default().into(), + ..Default::default() + }) + .collect()) + } else { + Err(anyhow!("HTTP header information parsing failed!")) + } +} + +/// Returns (status_code, body_text). Separating status so the wrapper can decide on fallback. +async fn post_request_http(url: &str, body: &str, header: &str) -> ResultType<(u16, String)> { let proxy_conf = Config::get_socks(); - let tls_url = get_url_for_tls(&url, &proxy_conf); + let tls_url = get_url_for_tls(url, &proxy_conf); let tls_type = get_cached_tls_type(tls_url); let danger_accept_invalid_cert = get_cached_tls_accept_invalid_cert(tls_url); let response = post_request_( - &url, + url, tls_url, - body.clone(), + body.to_owned(), header, tls_type, danger_accept_invalid_cert, danger_accept_invalid_cert, ) .await?; - Ok(response.text().await?) + let status = response.status().as_u16(); + let text = response.text().await?; + Ok((status, text)) +} + +/// Try `http_fn` first; on connection failure or 5xx, fall back to `tcp_fn` +/// if the URL is eligible. 4xx responses are returned as-is. +async fn with_tcp_proxy_fallback( + url: &str, + method: &str, + http_fn: HttpFut, + tcp_fn: TcpFut, +) -> ResultType +where + HttpFut: Future>, + TcpFut: Future>, +{ + if should_use_raw_tcp_for_api(url) { + return tcp_fn.await; + } + + let http_result = http_fn.await; + let should_fallback = match &http_result { + Err(_) => true, + Ok((status, _)) => *status >= 500, + }; + + if should_fallback && can_fallback_to_raw_tcp(url) { + log::warn!( + "HTTP {} to {} failed or 5xx (result: {:?}), trying TCP proxy fallback", + method, + tcp_proxy_log_target(url), + http_result + .as_ref() + .map(|(s, _)| *s) + .map_err(|e| e.to_string()), + ); + match tcp_fn.await { + Ok(resp) => return Ok(resp), + Err(tcp_err) => { + log::warn!("TCP proxy fallback also failed: {:?}", tcp_err); + } + } + } + + http_result.map(|(_status, text)| text) +} + +/// POST request with raw TCP proxy support. +/// - If `USE_RAW_TCP_FOR_API` is "Y" and WS is off, goes directly through TCP proxy. +/// - Otherwise tries HTTP first; on connection failure or 5xx status, +/// falls back to TCP proxy if WS is off. +/// - 4xx responses are returned as-is (server is reachable, business logic error). +/// - If fallback also fails, returns the original HTTP result (text or error). +pub async fn post_request(url: String, body: String, header: &str) -> ResultType { + with_tcp_proxy_fallback( + &url, + "POST", + post_request_http(&url, &body, header), + post_request_via_tcp_proxy(&url, &body, header), + ) + .await } #[async_recursion] @@ -1246,21 +1511,16 @@ async fn get_http_response_async( tls_type.unwrap_or(TlsType::Rustls), danger_accept_invalid_cert.unwrap_or(false), ); - let mut http_client = match method { + let normalized_method = method.to_ascii_lowercase(); + let mut http_client = match normalized_method.as_str() { "get" => http_client.get(url), "post" => http_client.post(url), "put" => http_client.put(url), "delete" => http_client.delete(url), _ => return Err(anyhow!("The HTTP request method is not supported!")), }; - let v = serde_json::from_str(header)?; - - if let Value::Object(obj) = v { - for (key, value) in obj.iter() { - http_client = http_client.header(key, value.as_str().unwrap_or_default()); - } - } else { - return Err(anyhow!("HTTP header information parsing failed!")); + for entry in parse_json_header_entries(header)? { + http_client = http_client.header(entry.name, entry.value); } if tls_type.is_some() && danger_accept_invalid_cert.is_some() { @@ -1340,51 +1600,80 @@ async fn get_http_response_async( } } -#[tokio::main(flavor = "current_thread")] -pub async fn http_request_sync( - url: String, - method: String, +/// Returns (status_code, json_string) so the caller can inspect the status +/// without re-parsing the serialized JSON. +async fn http_request_http( + url: &str, + method: &str, body: Option, - header: String, -) -> ResultType { + header: &str, +) -> ResultType<(u16, String)> { let proxy_conf = Config::get_socks(); - let tls_url = get_url_for_tls(&url, &proxy_conf); + let tls_url = get_url_for_tls(url, &proxy_conf); let tls_type = get_cached_tls_type(tls_url); let danger_accept_invalid_cert = get_cached_tls_accept_invalid_cert(tls_url); let response = get_http_response_async( - &url, + url, tls_url, - &method, - body.clone(), - &header, + method, + body, + header, tls_type, danger_accept_invalid_cert, danger_accept_invalid_cert, ) .await?; // Serialize response headers - let mut response_headers = serde_json::map::Map::new(); + let mut response_headers = Map::new(); for (key, value) in response.headers() { - response_headers.insert( - key.to_string(), - serde_json::json!(value.to_str().unwrap_or("")), - ); + response_headers.insert(key.to_string(), json!(value.to_str().unwrap_or(""))); } let status_code = response.status().as_u16(); let response_body = response.text().await?; // Construct the JSON object - let mut result = serde_json::map::Map::new(); - result.insert("status_code".to_string(), serde_json::json!(status_code)); - result.insert( - "headers".to_string(), - serde_json::Value::Object(response_headers), - ); - result.insert("body".to_string(), serde_json::json!(response_body)); + let mut result = Map::new(); + result.insert("status_code".to_string(), json!(status_code)); + result.insert("headers".to_string(), Value::Object(response_headers)); + result.insert("body".to_string(), json!(response_body)); // Convert map to JSON string - serde_json::to_string(&result).map_err(|e| anyhow!("Failed to serialize response: {}", e)) + let json_str = serde_json::to_string(&result) + .map_err(|e| anyhow!("Failed to serialize response: {}", e))?; + Ok((status_code, json_str)) +} + +/// HTTP request with raw TCP proxy support. +#[tokio::main(flavor = "current_thread")] +pub async fn http_request_sync( + url: String, + method: String, + body: Option, + header: String, +) -> ResultType { + with_tcp_proxy_fallback( + &url, + &method, + http_request_http(&url, &method, body.clone(), &header), + http_request_via_tcp_proxy(&url, &method, body.as_deref(), &header), + ) + .await +} + +/// General HTTP request via TCP proxy. Header is a JSON string (used by http_request_sync). +/// Returns a JSON string with status_code, headers, body (same format as http_request_sync). +async fn http_request_via_tcp_proxy( + url: &str, + method: &str, + body: Option<&str>, + header: &str, +) -> ResultType { + let headers = parse_json_header_entries(header)?; + let body_bytes = body.unwrap_or("").as_bytes(); + + let resp = tcp_proxy_request(method, url, body_bytes, headers).await?; + http_proxy_response_to_json(resp) } #[inline] @@ -1647,7 +1936,7 @@ pub fn check_process(arg: &str, mut same_uid: bool) -> bool { false } -pub async fn secure_tcp(conn: &mut Stream, key: &str) -> ResultType<()> { +async fn secure_tcp_impl(conn: &mut Stream, key: &str, log_on_success: bool) -> ResultType<()> { // Skip additional encryption when using WebSocket connections (wss://) // as WebSocket Secure (wss://) already provides transport layer encryption. // This doesn't affect the end-to-end encryption between clients, @@ -1680,7 +1969,9 @@ pub async fn secure_tcp(conn: &mut Stream, key: &str) -> ResultType<()> { }); timeout(CONNECT_TIMEOUT, conn.send(&msg_out)).await??; conn.set_key(key); - log::info!("Connection secured"); + if log_on_success { + log::info!("Connection secured"); + } } _ => {} } @@ -1691,6 +1982,14 @@ pub async fn secure_tcp(conn: &mut Stream, key: &str) -> ResultType<()> { Ok(()) } +pub async fn secure_tcp(conn: &mut Stream, key: &str) -> ResultType<()> { + secure_tcp_impl(conn, key, true).await +} + +async fn secure_tcp_silent(conn: &mut Stream, key: &str) -> ResultType<()> { + secure_tcp_impl(conn, key, false).await +} + #[inline] fn get_pk(pk: &[u8]) -> Option<[u8; 32]> { if pk.len() == 32 { @@ -2468,11 +2767,13 @@ mod tests { assert!(is_public("https://rustdesk.com/")); assert!(is_public("https://www.rustdesk.com/")); assert!(is_public("https://api.rustdesk.com/v1")); + assert!(is_public("https://API.RUSTDESK.COM/v1")); assert!(is_public("https://rustdesk.com/path")); // Test URLs ending with "rustdesk.com" assert!(is_public("rustdesk.com")); assert!(is_public("https://rustdesk.com")); + assert!(is_public("https://RustDesk.com")); assert!(is_public("http://www.rustdesk.com")); assert!(is_public("https://api.rustdesk.com")); @@ -2485,6 +2786,193 @@ mod tests { assert!(!is_public("rustdesk.comhello.com")); } + #[test] + fn test_should_use_tcp_proxy_for_api_url() { + assert!(should_use_tcp_proxy_for_api_url( + "https://admin.example.com/api/login", + "https://admin.example.com" + )); + assert!(should_use_tcp_proxy_for_api_url( + "https://admin.example.com:21114/api/login", + "https://admin.example.com" + )); + assert!(!should_use_tcp_proxy_for_api_url( + "https://api.telegram.org/bot123/sendMessage", + "https://admin.example.com" + )); + assert!(!should_use_tcp_proxy_for_api_url( + "https://admin.rustdesk.com/api/login", + "https://admin.rustdesk.com" + )); + assert!(!should_use_tcp_proxy_for_api_url( + "https://admin.example.com/api/login", + "not a url" + )); + assert!(!should_use_tcp_proxy_for_api_url( + "not a url", + "https://admin.example.com" + )); + } + + #[test] + fn test_get_tcp_proxy_addr_normalizes_bare_ipv6_host() { + struct RestoreCustomRendezvousServer(String); + + impl Drop for RestoreCustomRendezvousServer { + fn drop(&mut self) { + Config::set_option( + keys::OPTION_CUSTOM_RENDEZVOUS_SERVER.to_string(), + self.0.clone(), + ); + } + } + + let _restore = RestoreCustomRendezvousServer(Config::get_option( + keys::OPTION_CUSTOM_RENDEZVOUS_SERVER, + )); + Config::set_option( + keys::OPTION_CUSTOM_RENDEZVOUS_SERVER.to_string(), + "1:2".to_string(), + ); + + assert_eq!(get_tcp_proxy_addr(), format!("[1:2]:{RENDEZVOUS_PORT}")); + } + + #[tokio::test] + async fn test_http_request_via_tcp_proxy_rejects_invalid_header_json() { + let result = http_request_via_tcp_proxy("not a url", "get", None, "{").await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_http_request_via_tcp_proxy_rejects_non_object_header_json() { + let err = http_request_via_tcp_proxy("not a url", "get", None, "[]") + .await + .unwrap_err() + .to_string(); + assert!(err.contains("HTTP header information parsing failed!")); + } + + #[test] + fn test_parse_json_header_entries_preserves_single_content_type() { + let headers = parse_json_header_entries( + r#"{"Content-Type":"text/plain","Authorization":"Bearer token"}"#, + ) + .unwrap(); + + assert_eq!( + headers + .iter() + .filter(|entry| entry.name.eq_ignore_ascii_case("Content-Type")) + .count(), + 1 + ); + assert_eq!( + headers + .iter() + .find(|entry| entry.name.eq_ignore_ascii_case("Content-Type")) + .map(|entry| entry.value.as_str()), + Some("text/plain") + ); + } + + #[test] + fn test_parse_json_header_entries_does_not_add_default_content_type() { + let headers = parse_json_header_entries(r#"{"Authorization":"Bearer token"}"#).unwrap(); + + assert!(!headers + .iter() + .any(|entry| entry.name.eq_ignore_ascii_case("Content-Type"))); + } + + #[test] + fn test_parse_simple_header_respects_custom_content_type() { + let headers = parse_simple_header("Content-Type: text/plain"); + + assert_eq!( + headers + .iter() + .filter(|entry| entry.name.eq_ignore_ascii_case("Content-Type")) + .count(), + 1 + ); + assert_eq!( + headers + .iter() + .find(|entry| entry.name.eq_ignore_ascii_case("Content-Type")) + .map(|entry| entry.value.as_str()), + Some("text/plain") + ); + } + + #[test] + fn test_parse_simple_header_preserves_non_content_type_header() { + let headers = parse_simple_header("Authorization: Bearer token"); + + assert!(headers.iter().any(|entry| { + entry.name.eq_ignore_ascii_case("Authorization") + && entry.value.as_str() == "Bearer token" + })); + assert_eq!( + headers + .iter() + .filter(|entry| entry.name.eq_ignore_ascii_case("Content-Type")) + .count(), + 1 + ); + assert_eq!( + headers + .iter() + .find(|entry| entry.name.eq_ignore_ascii_case("Content-Type")) + .map(|entry| entry.value.as_str()), + Some("application/json") + ); + } + + #[test] + fn test_tcp_proxy_log_target_redacts_query_only() { + assert_eq!( + tcp_proxy_log_target("https://example.com/api/heartbeat?token=secret"), + "https://example.com/api/heartbeat" + ); + } + + #[test] + fn test_tcp_proxy_log_target_brackets_ipv6_host_with_port() { + assert_eq!( + tcp_proxy_log_target("https://[2001:db8::1]:21114/api/heartbeat?token=secret"), + "https://[2001:db8::1]:21114/api/heartbeat" + ); + } + + #[test] + fn test_http_proxy_response_to_json() { + let mut resp = HttpProxyResponse { + status: 200, + body: br#"{"ok":true}"#.to_vec().into(), + ..Default::default() + }; + resp.headers.push(HeaderEntry { + name: "Content-Type".into(), + value: "application/json".into(), + ..Default::default() + }); + + let json = http_proxy_response_to_json(resp).unwrap(); + let value: Value = serde_json::from_str(&json).unwrap(); + assert_eq!(value["status_code"], 200); + assert_eq!(value["headers"]["content-type"], "application/json"); + assert_eq!(value["body"], r#"{"ok":true}"#); + + let err = http_proxy_response_to_json(HttpProxyResponse { + error: "dial failed".into(), + ..Default::default() + }) + .unwrap_err() + .to_string(); + assert!(err.contains("TCP proxy error: dial failed")); + } + #[test] fn test_mouse_event_constants_and_mask_layout() { use super::input::*; diff --git a/src/hbbs_http.rs b/src/hbbs_http.rs index 20316b6f5f8..9e4538697a2 100644 --- a/src/hbbs_http.rs +++ b/src/hbbs_http.rs @@ -1,4 +1,4 @@ -use reqwest::blocking::Response; +use hbb_common::ResultType; use serde::de::DeserializeOwned; use serde_json::{Map, Value}; @@ -21,11 +21,9 @@ pub enum HbbHttpResponse { Data(T), } -impl TryFrom for HbbHttpResponse { - type Error = reqwest::Error; - - fn try_from(resp: Response) -> Result>::Error> { - let map = resp.json::>()?; +impl HbbHttpResponse { + pub fn parse(body: &str) -> ResultType { + let map = serde_json::from_str::>(body)?; if let Some(error) = map.get("error") { if let Some(err) = error.as_str() { Ok(Self::Error(err.to_owned())) diff --git a/src/hbbs_http/account.rs b/src/hbbs_http/account.rs index 8e614120062..3f824113b17 100644 --- a/src/hbbs_http/account.rs +++ b/src/hbbs_http/account.rs @@ -1,7 +1,6 @@ use super::HbbHttpResponse; use crate::hbbs_http::create_http_client_with_url; use hbb_common::{config::LocalConfig, log, ResultType}; -use reqwest::blocking::Client; use serde_derive::{Deserialize, Serialize}; use serde_repr::{Deserialize_repr, Serialize_repr}; use std::{ @@ -109,7 +108,7 @@ pub struct AuthBody { } pub struct OidcSession { - client: Option, + warmed_api_server: Option, state_msg: &'static str, failed_msg: String, code_url: Option, @@ -136,7 +135,7 @@ impl Default for UserStatus { impl OidcSession { fn new() -> Self { Self { - client: None, + warmed_api_server: None, state_msg: REQUESTING_ACCOUNT_AUTH, failed_msg: "".to_owned(), code_url: None, @@ -149,12 +148,13 @@ impl OidcSession { fn ensure_client(api_server: &str) { let mut write_guard = OIDC_SESSION.write().unwrap(); - if write_guard.client.is_none() { - // This URL is used to detect the appropriate TLS implementation for the server. - let login_option_url = format!("{}/api/login-options", &api_server); - let client = create_http_client_with_url(&login_option_url); - write_guard.client = Some(client); + if write_guard.warmed_api_server.as_deref() == Some(api_server) { + return; } + // This URL is used to detect the appropriate TLS implementation for the server. + let login_option_url = format!("{}/api/login-options", api_server); + let _ = create_http_client_with_url(&login_option_url); + write_guard.warmed_api_server = Some(api_server.to_owned()); } fn auth( @@ -164,26 +164,15 @@ impl OidcSession { uuid: &str, ) -> ResultType> { Self::ensure_client(api_server); - let resp = if let Some(client) = &OIDC_SESSION.read().unwrap().client { - client - .post(format!("{}/api/oidc/auth", api_server)) - .json(&serde_json::json!({ - "op": op, - "id": id, - "uuid": uuid, - "deviceInfo": crate::ui_interface::get_login_device_info(), - })) - .send()? - } else { - hbb_common::bail!("http client not initialized"); - }; - let status = resp.status(); - match resp.try_into() { - Ok(v) => Ok(v), - Err(err) => { - hbb_common::bail!("Http status: {}, err: {}", status, err); - } - } + let body = serde_json::json!({ + "op": op, + "id": id, + "uuid": uuid, + "deviceInfo": crate::ui_interface::get_login_device_info(), + }) + .to_string(); + let resp = crate::post_request_sync(format!("{}/api/oidc/auth", api_server), body, "")?; + HbbHttpResponse::parse(&resp) } fn query( @@ -197,11 +186,19 @@ impl OidcSession { &[("code", code), ("id", id), ("uuid", uuid)], )?; Self::ensure_client(api_server); - if let Some(client) = &OIDC_SESSION.read().unwrap().client { - Ok(client.get(url).send()?.try_into()?) - } else { - hbb_common::bail!("http client not initialized") + #[derive(Deserialize)] + struct HttpResponseBody { + body: String, } + + let resp = crate::http_request_sync( + url.to_string(), + "GET".to_owned(), + None, + "{}".to_owned(), + )?; + let resp = serde_json::from_str::(&resp)?; + HbbHttpResponse::parse(&resp.body) } fn reset(&mut self) {