@@ -9,13 +9,11 @@ package dns53
99import (
1010 "context"
1111 "crypto/tls"
12- "errors"
1312 "fmt"
1413 "net"
1514 "net/netip"
1615 "net/url"
1716 "strconv"
18- "sync/atomic"
1917 "time"
2018
2119 x "github.com/celzero/firestack/intra/backend"
@@ -31,6 +29,8 @@ import (
3129
3230const usepool = true
3331
32+ const echRetryPeriod = 8 * time .Hour
33+
3434type dot struct {
3535 ctx context.Context
3636 done context.CancelFunc
@@ -43,14 +43,16 @@ type dot struct {
4343 host string // hostname from the url
4444
4545 c * dns.Client
46- c3 * dns.Client // with ech
4746 proxies ipn.ProxyProvider // may be nil
4847 relay string // may be empty
4948 skipTLSVerify bool
5049
5150 pool * core.MultConnPool [uintptr ]
5251 usepool bool
5352
53+ echconfig * core.Volatile [* tls.Config ] // echconfig for the endpoint; may be nil
54+ echlastattempt * core.Volatile [time.Time ] // last attempt fetching ech cfg
55+
5456 est core.P2QuantileEstimator
5557 status * core.Volatile [int ]
5658}
@@ -63,10 +65,7 @@ func NewTLSTransport(ctx context.Context, id, rawurl string, addrs []string, px
6365 MinVersion : tls .VersionTLS12 ,
6466 SessionTicketsDisabled : false ,
6567 }
66- echcfg := & tls.Config {
67- MinVersion : tls .VersionTLS13 ,
68- SessionTicketsDisabled : false ,
69- }
68+
7069 // rawurl is either tls:host[:port] or tls://host[:port] or host[:port]
7170 parsedurl , err := url .Parse (rawurl )
7271 if err != nil {
@@ -76,7 +75,6 @@ func NewTLSTransport(ctx context.Context, id, rawurl string, addrs []string, px
7675 if parsedurl .Scheme != "tls" {
7776 log .I ("dot: disabling tls verification for %s" , rawurl )
7877 tlscfg .InsecureSkipVerify = true
79- echcfg .InsecureSkipVerify = true
8078 skipTLSVerify = true
8179 }
8280 var relay string
@@ -97,32 +95,28 @@ func NewTLSTransport(ctx context.Context, id, rawurl string, addrs []string, px
9795 tlscfg .ClientSessionCache = core .TlsSessionCache ()
9896 addrport , port := url2addrport (rawurl )
9997 t = & dot {
100- ctx : ctx ,
101- done : done ,
102- id : id ,
103- url : rawurl ,
104- host : hostname ,
105- skipTLSVerify : skipTLSVerify ,
106- addrport : addrport , // may or may not be ipaddr
107- port : port ,
108- status : core .NewVolatile (x .Start ),
109- proxies : px ,
110- relay : relay ,
111- pool : core.NewMultConnPool [uintptr ](ctx ),
112- usepool : usepool ,
113- est : core .NewP50Estimator (ctx ),
114- }
115- ech := t .ech ()
116- if len (ech ) > 0 {
117- echcfg .ClientSessionCache = core .TlsSessionCache ()
118- echcfg .EncryptedClientHelloConfigList = ech
119- echcfg .EncryptedClientHelloRejectionVerify = t .echVerifyFn ()
120- t .c3 = dnsclient (echcfg )
121- }
98+ ctx : ctx ,
99+ done : done ,
100+ id : id ,
101+ url : rawurl ,
102+ host : hostname ,
103+ skipTLSVerify : skipTLSVerify ,
104+ addrport : addrport , // may or may not be ipaddr
105+ port : port ,
106+ status : core .NewVolatile (x .Start ),
107+ proxies : px ,
108+ relay : relay ,
109+ pool : core.NewMultConnPool [uintptr ](ctx ),
110+ usepool : usepool ,
111+ est : core .NewP50Estimator (ctx ),
112+ echconfig : core .NewZeroVolatile [* tls.Config ](),
113+ echlastattempt : core .NewZeroVolatile [time.Time ](),
114+ }
115+ echcfg := t .getOrCreateEchConfigIfNeeded ()
122116 // local dialer: protect.MakeNsDialer(id, ctl)
123117 t .c = dnsclient (tlscfg )
124118 log .I ("dot: (%s) setup: %s; relay? %t; resolved? %t, ech? %t" ,
125- id , rawurl , len (relay ) > 0 , ok , len ( ech ) > 0 )
119+ id , rawurl , len (relay ) > 0 , ok , echcfg != nil )
126120 return t , nil
127121}
128122
@@ -131,8 +125,8 @@ func dnsclient(c *tls.Config) *dns.Client {
131125 Net : "tcp-tls" ,
132126 Dialer : nil , // unused; dialers from px take precedence
133127 Timeout : dottimeout ,
134- SingleInflight : true , // coalsece queries
135- TLSConfig : c .Clone (),
128+ SingleInflight : true , // coalsece queries
129+ TLSConfig : c .Clone (), // may be left unused
136130 }
137131}
138132
@@ -156,7 +150,7 @@ func (t *dot) echVerifyFn() func(tls.ConnectionState) error {
156150 return nil // delegate to stdlib
157151}
158152
159- func (t * dot ) doQuery (pid string , q * dns.Msg ) (response * dns.Msg , rpid string , elapsed time.Duration , qerr * dnsx.QueryError ) {
153+ func (t * dot ) doQuery (pid string , q * dns.Msg ) (response * dns.Msg , rpid string , ech bool , elapsed time.Duration , qerr * dnsx.QueryError ) {
160154 if q == nil || ! xdns .HasAnyQuestion (q ) {
161155 qerr = dnsx .NewBadQueryError (fmt .Errorf ("err len(query) %d" , xdns .Len (q )))
162156 return
@@ -167,15 +161,15 @@ func (t *dot) doQuery(pid string, q *dns.Msg) (response *dns.Msg, rpid string, e
167161 return
168162 }
169163
170- response , rpid , elapsed , qerr = t .sendRequest (pid , q )
164+ response , rpid , ech , elapsed , qerr = t .sendRequest (pid , q )
171165
172166 if qerr != nil { // only on send-request errors
173167 response = xdns .Servfail (q )
174168 }
175169 return
176170}
177171
178- func (t * dot ) tlsdial (p ipn.Proxy ) (dc * dns.Conn , who uintptr , err error ) {
172+ func (t * dot ) tlsdial (p ipn.Proxy ) (dc * dns.Conn , who uintptr , usingech bool , err error ) {
179173 who = p .Handle ()
180174
181175 defer func () {
@@ -189,52 +183,58 @@ func (t *dot) tlsdial(p ipn.Proxy) (dc *dns.Conn, who uintptr, err error) {
189183 }()
190184
191185 if dc = t .fromPool (who ); dc != nil {
192- return
186+ return dc , who , false , nil // pooled connections don't track ECH state
193187 }
194188
195- var usingech bool
196189 var c net.Conn = nil // dot is always tcp
197190 addr := t .addrport // t.addr may be ip or hostname
198- if t .c3 != nil { // may be nil if ech is not available
199- cfg := t .c3 .TLSConfig // don't clone; may be modified by dialers.DialWithTls
200- c , err = dialers .DialWithTls (p .Dialer (), cfg , "tcp" , addr )
201- usingech = true
191+
192+ // Try ECH first if available
193+ if echcfg := t .getOrCreateEchConfigIfNeeded (); echcfg != nil {
194+ // update ech config which may have been changed by DialWithTls
195+ defer t .echconfig .Store (echcfg )
196+ c , err = dialers .DialWithTls (p .Dialer (), echcfg , "tcp" , addr )
202197 }
198+
203199 if c == nil && core .IsNil (c ) { // no ech or ech failed
204200 cfg := t .c .TLSConfig
205201 c , err = dialers .DialWithTls (p .Dialer (), cfg , "tcp" , addr )
202+ usingech = false
206203 }
207204 if c != nil && core .IsNotNil (c ) {
208- return & dns.Conn {Conn : c }, who , err
205+ if tlsConn , ok := c .(* tls.Conn ); ok && usingech {
206+ usingech = tlsConn .ConnectionState ().ECHAccepted
207+ }
208+ return & dns.Conn {Conn : c }, who , usingech , err
209209 } else {
210210 err = core .OneErr (err , errNoNet )
211211 log .W ("dot: tlsdial: (%s) nil conn/err for %s, ech? %t; err? %v" ,
212212 t .id , addr , usingech , err )
213213 }
214- return nil , who , err
214+ return nil , who , false , err
215215}
216216
217- func (t * dot ) pxdial (pid string ) (* dns.Conn , string , uintptr , error ) {
217+ func (t * dot ) pxdial (pid string ) (* dns.Conn , string , uintptr , bool , error ) {
218218 var px ipn.Proxy
219219 if len (t .relay ) > 0 { // relay takes precedence
220220 pid = t .relay
221221 }
222222 if t .proxies != nil { // err if t.proxies is nil
223223 var err error
224224 if px , err = t .proxies .ProxyFor (pid ); err != nil {
225- return nil , "" , core .Nobody , err
225+ return nil , "" , core .Nobody , false , err
226226 }
227227 }
228228 if px == nil {
229- return nil , "" , core .Nobody , dnsx .ErrNoProxyProvider
229+ return nil , "" , core .Nobody , false , dnsx .ErrNoProxyProvider
230230 }
231231 pid = px .ID ().V ()
232232 rpid := ipn .ViaID (px )
233233 log .V ("dot: pxdial: (%s) using relay/proxy %s (via: %s) at %s" ,
234234 t .id , pid , rpid , px .GetAddr ())
235235
236- c , who , err := t .tlsdial (px )
237- return c , rpid , who , err
236+ c , who , ech , err := t .tlsdial (px )
237+ return c , rpid , who , ech , err
238238}
239239
240240// toPool takes ownership of c.
@@ -269,7 +269,7 @@ func clos(c net.Conn) {
269269 core .CloseConn (c )
270270}
271271
272- func (t * dot ) sendRequest (pid string , q * dns.Msg ) (ans * dns.Msg , rpid string , elapsed time.Duration , qerr * dnsx.QueryError ) {
272+ func (t * dot ) sendRequest (pid string , q * dns.Msg ) (ans * dns.Msg , rpid string , ech bool , elapsed time.Duration , qerr * dnsx.QueryError ) {
273273 var err error
274274
275275 if q == nil || ! xdns .HasAnyQuestion (q ) {
@@ -282,12 +282,13 @@ func (t *dot) sendRequest(pid string, q *dns.Msg) (ans *dns.Msg, rpid string, el
282282 userelay := len (t .relay ) > 0
283283 useproxy := len (pid ) != 0 // pid == dnsx.NetNoProxy => ipn.Block
284284 if useproxy || userelay { // ref dns.Client.Dial
285- conn , rpid , who , err = t .pxdial (pid )
285+ conn , rpid , who , ech , err = t .pxdial (pid )
286286 } else {
287287 err = dnsx .ErrNoProxyProvider
288288 }
289289
290290 if err == nil {
291+ // tls config is not used with this exchange as conn is pre-supplied
291292 ans , elapsed , err = t .c .ExchangeWithConnContext (t .ctx , q , conn )
292293 } // fallthrough
293294
@@ -316,6 +317,7 @@ func (t *dot) Query(network string, q *dns.Msg, smm *x.DNSSummary) (ans *dns.Msg
316317 var qerr * dnsx.QueryError
317318 var elapsed time.Duration
318319 var pid , rpid string
320+ var ech bool
319321
320322 if r := t .relay ; len (r ) > 0 {
321323 pid = t .chooseProxy (r )
@@ -324,7 +326,7 @@ func (t *dot) Query(network string, q *dns.Msg, smm *x.DNSSummary) (ans *dns.Msg
324326 pid = t .chooseProxy (pids ... )
325327 }
326328
327- ans , rpid , elapsed , qerr = t .doQuery (pid , q )
329+ ans , rpid , ech , elapsed , qerr = t .doQuery (pid , q )
328330
329331 status := dnsx .Complete
330332 if qerr != nil {
@@ -339,6 +341,9 @@ func (t *dot) Query(network string, q *dns.Msg, smm *x.DNSSummary) (ans *dns.Msg
339341 smm .RCode = xdns .Rcode (ans )
340342 smm .RTtl = xdns .RTtl (ans )
341343 smm .Server = t .getAddr ()
344+ if ech {
345+ smm .Server = dnsx .EchPrefix + smm .Server
346+ }
342347 smm .PID = pid // may be local dnsx.IsLocalProxy
343348 smm .RPID = rpid // may be empty
344349 if err != nil {
@@ -347,8 +352,8 @@ func (t *dot) Query(network string, q *dns.Msg, smm *x.DNSSummary) (ans *dns.Msg
347352 smm .Status = status
348353 t .est .Add (smm .Latency )
349354
350- log .V ("dot: len(res): fro %s:%d a:%d/sz:%d/pad:%d, data: %s, via: %s, err? %v" ,
351- smm .QName , smm .QType , xdns .Len (ans ), xdns .Size (ans ), xdns .EDNS0PadLen (ans ), smm .RData , smm .PID , err )
355+ log .V ("dot: %s ech? %t; len(res): fro %s:%d a:%d/sz:%d/pad:%d, data: %s / status: %d , via: %s, err? %v" ,
356+ t . id , ech , smm .QName , smm .QType , xdns .Len (ans ), xdns .Size (ans ), xdns .EDNS0PadLen (ans ), smm .RData , smm . Status , smm .PID , err )
352357
353358 return
354359}
@@ -378,7 +383,7 @@ func (t *dot) GetRelay() x.Proxy {
378383}
379384
380385func (t * dot ) getAddr () (addr string ) {
381- if t .c3 != nil {
386+ if t .echconfig . Load () != nil {
382387 addr = dnsx .EchPrefix + t .addrport
383388 } else if t .skipTLSVerify {
384389 addr = dnsx .NoPkiPrefix + t .addrport
@@ -426,9 +431,34 @@ func url2addrport(url string) (string, uint16) {
426431 return url , port
427432}
428433
429- func logwif (cond bool ) log.LogFn {
430- if cond {
431- return log .W
434+ func (t * dot ) getOrCreateEchConfigIfNeeded () * tls.Config {
435+ echcfg := t .echconfig .Load ()
436+ if echcfg != nil {
437+ return echcfg
438+ }
439+
440+ prev := t .echlastattempt .Load ()
441+ if time .Since (prev ) < echRetryPeriod {
442+ return nil
443+ }
444+ refetch := t .echlastattempt .Cas (prev , time .Now ())
445+ if ! refetch {
446+ return nil
432447 }
433- return log .V
448+
449+ if ech := t .ech (); len (ech ) > 0 {
450+ echcfg = & tls.Config {
451+ InsecureSkipVerify : t .skipTLSVerify ,
452+ MinVersion : tls .VersionTLS13 , // must be 1.3
453+ EncryptedClientHelloConfigList : ech ,
454+ SessionTicketsDisabled : false ,
455+ ClientSessionCache : core .TlsSessionCache (),
456+ EncryptedClientHelloRejectionVerify : t .echVerifyFn (),
457+ }
458+ t .echconfig .Store (echcfg )
459+ }
460+
461+ ok := echcfg != nil
462+ logwif (! ok )("dot: %s fetch ech... ok? %t" , t .id , ok )
463+ return echcfg
434464}
0 commit comments