@@ -3,12 +3,60 @@ use std::sync::Arc;
33use tokio:: io:: { AsyncReadExt , AsyncWriteExt } ;
44use tokio:: net:: { TcpListener , TcpStream } ;
55use tokio:: sync:: Mutex ;
6- use tokio_rustls:: TlsAcceptor ;
6+ use tokio_rustls:: rustls:: client:: danger:: {
7+ HandshakeSignatureValid , ServerCertVerified , ServerCertVerifier ,
8+ } ;
9+ use tokio_rustls:: rustls:: pki_types:: { CertificateDer , ServerName , UnixTime } ;
10+ use tokio_rustls:: rustls:: { ClientConfig , DigitallySignedStruct , SignatureScheme } ;
11+ use tokio_rustls:: { TlsAcceptor , TlsConnector } ;
712
813use crate :: config:: Config ;
914use crate :: domain_fronter:: DomainFronter ;
1015use crate :: mitm:: MitmCertManager ;
1116
17+ const SNI_REWRITE_SUFFIXES : & [ & str ] = & [
18+ "youtube.com" ,
19+ "youtu.be" ,
20+ "youtube-nocookie.com" ,
21+ "youtubeeducation.com" ,
22+ "googlevideo.com" ,
23+ "ytimg.com" ,
24+ "ggpht.com" ,
25+ "gvt1.com" ,
26+ "gvt2.com" ,
27+ "doubleclick.net" ,
28+ "googlesyndication.com" ,
29+ "googleadservices.com" ,
30+ "google-analytics.com" ,
31+ "googletagmanager.com" ,
32+ "googletagservices.com" ,
33+ "fonts.googleapis.com" ,
34+ ] ;
35+
36+ fn matches_sni_rewrite ( host : & str ) -> bool {
37+ let h = host. to_ascii_lowercase ( ) ;
38+ let h = h. trim_end_matches ( '.' ) ;
39+ SNI_REWRITE_SUFFIXES
40+ . iter ( )
41+ . any ( |s| h == * s || h. ends_with ( & format ! ( ".{}" , s) ) )
42+ }
43+
44+ fn hosts_override < ' a > ( hosts : & ' a std:: collections:: HashMap < String , String > , host : & str ) -> Option < & ' a str > {
45+ let h = host. to_ascii_lowercase ( ) ;
46+ let h = h. trim_end_matches ( '.' ) ;
47+ if let Some ( ip) = hosts. get ( h) {
48+ return Some ( ip. as_str ( ) ) ;
49+ }
50+ let parts: Vec < & str > = h. split ( '.' ) . collect ( ) ;
51+ for i in 1 ..parts. len ( ) {
52+ let parent = parts[ i..] . join ( "." ) ;
53+ if let Some ( ip) = hosts. get ( & parent) {
54+ return Some ( ip. as_str ( ) ) ;
55+ }
56+ }
57+ None
58+ }
59+
1260#[ derive( Debug , thiserror:: Error ) ]
1361pub enum ProxyError {
1462 #[ error( "io: {0}" ) ]
@@ -20,17 +68,48 @@ pub struct ProxyServer {
2068 port : u16 ,
2169 fronter : Arc < DomainFronter > ,
2270 mitm : Arc < Mutex < MitmCertManager > > ,
71+ rewrite_ctx : Arc < RewriteCtx > ,
72+ }
73+
74+ pub struct RewriteCtx {
75+ pub google_ip : String ,
76+ pub front_domain : String ,
77+ pub hosts : std:: collections:: HashMap < String , String > ,
78+ pub tls_connector : TlsConnector ,
2379}
2480
2581impl ProxyServer {
2682 pub fn new ( config : & Config , mitm : Arc < Mutex < MitmCertManager > > ) -> Result < Self , ProxyError > {
2783 let fronter = DomainFronter :: new ( config)
2884 . map_err ( |e| std:: io:: Error :: new ( std:: io:: ErrorKind :: Other , format ! ( "{e}" ) ) ) ?;
85+
86+ let tls_config = if config. verify_ssl {
87+ let mut roots = tokio_rustls:: rustls:: RootCertStore :: empty ( ) ;
88+ roots. extend ( webpki_roots:: TLS_SERVER_ROOTS . iter ( ) . cloned ( ) ) ;
89+ ClientConfig :: builder ( )
90+ . with_root_certificates ( roots)
91+ . with_no_client_auth ( )
92+ } else {
93+ ClientConfig :: builder ( )
94+ . dangerous ( )
95+ . with_custom_certificate_verifier ( Arc :: new ( NoVerify ) )
96+ . with_no_client_auth ( )
97+ } ;
98+ let tls_connector = TlsConnector :: from ( Arc :: new ( tls_config) ) ;
99+
100+ let rewrite_ctx = Arc :: new ( RewriteCtx {
101+ google_ip : config. google_ip . clone ( ) ,
102+ front_domain : config. front_domain . clone ( ) ,
103+ hosts : config. hosts . clone ( ) ,
104+ tls_connector,
105+ } ) ;
106+
29107 Ok ( Self {
30108 host : config. listen_host . clone ( ) ,
31109 port : config. listen_port ,
32110 fronter : Arc :: new ( fronter) ,
33111 mitm,
112+ rewrite_ctx,
34113 } )
35114 }
36115
@@ -53,8 +132,9 @@ impl ProxyServer {
53132 let _ = sock. set_nodelay ( true ) ;
54133 let fronter = self . fronter . clone ( ) ;
55134 let mitm = self . mitm . clone ( ) ;
135+ let rewrite_ctx = self . rewrite_ctx . clone ( ) ;
56136 tokio:: spawn ( async move {
57- if let Err ( e) = handle_client ( sock, fronter, mitm) . await {
137+ if let Err ( e) = handle_client ( sock, fronter, mitm, rewrite_ctx ) . await {
58138 tracing:: debug!( "client {} closed: {}" , peer, e) ;
59139 }
60140 } ) ;
@@ -66,6 +146,7 @@ async fn handle_client(
66146 mut sock : TcpStream ,
67147 fronter : Arc < DomainFronter > ,
68148 mitm : Arc < Mutex < MitmCertManager > > ,
149+ rewrite_ctx : Arc < RewriteCtx > ,
69150) -> std:: io:: Result < ( ) > {
70151 // Read the first request (head only).
71152 let ( head, leftover) = match read_http_head ( & mut sock) . await ? {
@@ -77,7 +158,12 @@ async fn handle_client(
77158 . ok_or_else ( || std:: io:: Error :: new ( std:: io:: ErrorKind :: InvalidData , "bad request" ) ) ?;
78159
79160 if method. eq_ignore_ascii_case ( "CONNECT" ) {
80- do_connect ( sock, & target, fronter, mitm) . await
161+ let ( host, port) = parse_host_port ( & target) ;
162+ if matches_sni_rewrite ( & host) || hosts_override ( & rewrite_ctx. hosts , & host) . is_some ( ) {
163+ do_sni_rewrite_connect ( sock, & host, port, mitm, rewrite_ctx) . await
164+ } else {
165+ do_connect ( sock, & target, fronter, mitm) . await
166+ }
81167 } else {
82168 do_plain_http ( sock, & head, & leftover, fronter) . await
83169 }
@@ -189,6 +275,142 @@ async fn do_connect(
189275 Ok ( ( ) )
190276}
191277
278+ async fn do_sni_rewrite_connect (
279+ mut sock : TcpStream ,
280+ host : & str ,
281+ port : u16 ,
282+ mitm : Arc < Mutex < MitmCertManager > > ,
283+ rewrite_ctx : Arc < RewriteCtx > ,
284+ ) -> std:: io:: Result < ( ) > {
285+ sock. write_all ( b"HTTP/1.1 200 Connection Established\r \n \r \n " ) . await ?;
286+ sock. flush ( ) . await ?;
287+
288+ let target_ip = hosts_override ( & rewrite_ctx. hosts , host)
289+ . map ( |s| s. to_string ( ) )
290+ . unwrap_or_else ( || rewrite_ctx. google_ip . clone ( ) ) ;
291+
292+ tracing:: info!(
293+ "SNI-rewrite tunnel -> {}:{} via {} (outbound SNI={})" ,
294+ host, port, target_ip, rewrite_ctx. front_domain
295+ ) ;
296+
297+ // Accept browser TLS with a cert we sign for `host`.
298+ let server_config = {
299+ let mut m = mitm. lock ( ) . await ;
300+ match m. get_server_config ( host) {
301+ Ok ( c) => c,
302+ Err ( e) => {
303+ tracing:: error!( "cert gen failed for {}: {}" , host, e) ;
304+ return Ok ( ( ) ) ;
305+ }
306+ }
307+ } ;
308+ let inbound = match TlsAcceptor :: from ( server_config) . accept ( sock) . await {
309+ Ok ( t) => t,
310+ Err ( e) => {
311+ tracing:: debug!( "inbound TLS accept failed for {}: {}" , host, e) ;
312+ return Ok ( ( ) ) ;
313+ }
314+ } ;
315+
316+ // Open outbound TLS to google_ip with SNI=front_domain.
317+ let upstream_tcp = match tokio:: time:: timeout (
318+ std:: time:: Duration :: from_secs ( 10 ) ,
319+ TcpStream :: connect ( ( target_ip. as_str ( ) , port) ) ,
320+ )
321+ . await
322+ {
323+ Ok ( Ok ( s) ) => s,
324+ Ok ( Err ( e) ) => {
325+ tracing:: debug!( "upstream connect failed for {}: {}" , host, e) ;
326+ return Ok ( ( ) ) ;
327+ }
328+ Err ( _) => {
329+ tracing:: debug!( "upstream connect timeout for {}" , host) ;
330+ return Ok ( ( ) ) ;
331+ }
332+ } ;
333+ let _ = upstream_tcp. set_nodelay ( true ) ;
334+
335+ let server_name = match ServerName :: try_from ( rewrite_ctx. front_domain . clone ( ) ) {
336+ Ok ( n) => n,
337+ Err ( e) => {
338+ tracing:: error!( "invalid front_domain '{}': {}" , rewrite_ctx. front_domain, e) ;
339+ return Ok ( ( ) ) ;
340+ }
341+ } ;
342+ let outbound = match rewrite_ctx
343+ . tls_connector
344+ . connect ( server_name, upstream_tcp)
345+ . await
346+ {
347+ Ok ( t) => t,
348+ Err ( e) => {
349+ tracing:: debug!( "outbound TLS connect failed for {}: {}" , host, e) ;
350+ return Ok ( ( ) ) ;
351+ }
352+ } ;
353+
354+ // Bridge decrypted bytes between the two TLS streams.
355+ let ( mut ir, mut iw) = tokio:: io:: split ( inbound) ;
356+ let ( mut or, mut ow) = tokio:: io:: split ( outbound) ;
357+ let client_to_server = async { tokio:: io:: copy ( & mut ir, & mut ow) . await } ;
358+ let server_to_client = async { tokio:: io:: copy ( & mut or, & mut iw) . await } ;
359+ tokio:: select! {
360+ _ = client_to_server => { }
361+ _ = server_to_client => { }
362+ }
363+ Ok ( ( ) )
364+ }
365+
366+ #[ derive( Debug ) ]
367+ struct NoVerify ;
368+
369+ impl ServerCertVerifier for NoVerify {
370+ fn verify_server_cert (
371+ & self ,
372+ _end_entity : & CertificateDer < ' _ > ,
373+ _intermediates : & [ CertificateDer < ' _ > ] ,
374+ _server_name : & ServerName < ' _ > ,
375+ _ocsp_response : & [ u8 ] ,
376+ _now : UnixTime ,
377+ ) -> Result < ServerCertVerified , tokio_rustls:: rustls:: Error > {
378+ Ok ( ServerCertVerified :: assertion ( ) )
379+ }
380+
381+ fn verify_tls12_signature (
382+ & self ,
383+ _message : & [ u8 ] ,
384+ _cert : & CertificateDer < ' _ > ,
385+ _dss : & DigitallySignedStruct ,
386+ ) -> Result < HandshakeSignatureValid , tokio_rustls:: rustls:: Error > {
387+ Ok ( HandshakeSignatureValid :: assertion ( ) )
388+ }
389+
390+ fn verify_tls13_signature (
391+ & self ,
392+ _message : & [ u8 ] ,
393+ _cert : & CertificateDer < ' _ > ,
394+ _dss : & DigitallySignedStruct ,
395+ ) -> Result < HandshakeSignatureValid , tokio_rustls:: rustls:: Error > {
396+ Ok ( HandshakeSignatureValid :: assertion ( ) )
397+ }
398+
399+ fn supported_verify_schemes ( & self ) -> Vec < SignatureScheme > {
400+ vec ! [
401+ SignatureScheme :: RSA_PKCS1_SHA256 ,
402+ SignatureScheme :: RSA_PKCS1_SHA384 ,
403+ SignatureScheme :: RSA_PKCS1_SHA512 ,
404+ SignatureScheme :: ECDSA_NISTP256_SHA256 ,
405+ SignatureScheme :: ECDSA_NISTP384_SHA384 ,
406+ SignatureScheme :: RSA_PSS_SHA256 ,
407+ SignatureScheme :: RSA_PSS_SHA384 ,
408+ SignatureScheme :: RSA_PSS_SHA512 ,
409+ SignatureScheme :: ED25519 ,
410+ ]
411+ }
412+ }
413+
192414fn parse_host_port ( target : & str ) -> ( String , u16 ) {
193415 if let Some ( ( h, p) ) = target. rsplit_once ( ':' ) {
194416 let port: u16 = p. parse ( ) . unwrap_or ( 443 ) ;
0 commit comments