Skip to content

Commit 41aee97

Browse files
authored
core: propagate ECH keys to the QUIC listener (#7670)
1 parent 441d5eb commit 41aee97

2 files changed

Lines changed: 72 additions & 1 deletion

File tree

listeners.go

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,10 @@ func (na NetworkAddress) ListenQUIC(ctx context.Context, portOffset uint, config
462462
sqs := newSharedQUICState(tlsConf)
463463
// http3.ConfigureTLSConfig only uses this field and tls App sets this field as well
464464
//nolint:gosec
465-
quicTlsConfig := &tls.Config{GetConfigForClient: sqs.getConfigForClient}
465+
quicTlsConfig := &tls.Config{
466+
GetConfigForClient: sqs.getConfigForClient,
467+
GetEncryptedClientHelloKeys: sqs.getEncryptedClientHelloKeys,
468+
}
466469
// Require clients to verify their source address when we're handling more than 1000 handshakes per second.
467470
// TODO: make tunable?
468471
limiter := rate.NewLimiter(1000, 1000)
@@ -540,6 +543,16 @@ func (sqs *sharedQUICState) getConfigForClient(ch *tls.ClientHelloInfo) (*tls.Co
540543
return sqs.activeTlsConf.GetConfigForClient(ch)
541544
}
542545

546+
// getEncryptedClientHelloKeys is used as tls.Config's GetEncryptedClientHelloKeys field.
547+
func (sqs *sharedQUICState) getEncryptedClientHelloKeys(ch *tls.ClientHelloInfo) ([]tls.EncryptedClientHelloKey, error) {
548+
sqs.rmu.RLock()
549+
defer sqs.rmu.RUnlock()
550+
if sqs.activeTlsConf.GetEncryptedClientHelloKeys == nil {
551+
return nil, nil
552+
}
553+
return sqs.activeTlsConf.GetEncryptedClientHelloKeys(ch)
554+
}
555+
543556
// addState adds tls.Config and activeRequests to the map if not present and returns the corresponding context and its cancelFunc
544557
// so that when cancelled, the active tls.Config will change
545558
func (sqs *sharedQUICState) addState(tlsConfig *tls.Config) (context.Context, context.CancelCauseFunc) {

listeners_test.go

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
package caddy
1616

1717
import (
18+
"crypto/tls"
1819
"reflect"
1920
"testing"
2021

@@ -175,6 +176,63 @@ func TestJoinNetworkAddress(t *testing.T) {
175176
}
176177
}
177178

179+
func TestSharedQUICStateGetEncryptedClientHelloKeys(t *testing.T) {
180+
hello := &tls.ClientHelloInfo{ServerName: "example.com"}
181+
initialKeys := []tls.EncryptedClientHelloKey{{Config: []byte("initial"), PrivateKey: []byte("initial-key")}}
182+
updatedKeys := []tls.EncryptedClientHelloKey{{Config: []byte("updated"), PrivateKey: []byte("updated-key")}}
183+
184+
initialConfig := &tls.Config{
185+
GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
186+
return nil, nil
187+
},
188+
GetEncryptedClientHelloKeys: func(*tls.ClientHelloInfo) ([]tls.EncryptedClientHelloKey, error) {
189+
return initialKeys, nil
190+
},
191+
}
192+
193+
sqs := newSharedQUICState(initialConfig)
194+
195+
keys, err := sqs.getEncryptedClientHelloKeys(hello)
196+
if err != nil {
197+
t.Fatalf("getting initial ECH keys: %v", err)
198+
}
199+
if !reflect.DeepEqual(keys, initialKeys) {
200+
t.Fatalf("unexpected initial ECH keys: got %#v, want %#v", keys, initialKeys)
201+
}
202+
203+
updatedConfig := &tls.Config{
204+
GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
205+
return nil, nil
206+
},
207+
GetEncryptedClientHelloKeys: func(*tls.ClientHelloInfo) ([]tls.EncryptedClientHelloKey, error) {
208+
return updatedKeys, nil
209+
},
210+
}
211+
212+
_, cancel := sqs.addState(updatedConfig)
213+
sqs.rmu.Lock()
214+
sqs.activeTlsConf = updatedConfig
215+
sqs.rmu.Unlock()
216+
217+
keys, err = sqs.getEncryptedClientHelloKeys(hello)
218+
if err != nil {
219+
t.Fatalf("getting updated ECH keys: %v", err)
220+
}
221+
if !reflect.DeepEqual(keys, updatedKeys) {
222+
t.Fatalf("unexpected updated ECH keys: got %#v, want %#v", keys, updatedKeys)
223+
}
224+
225+
cancel(nil)
226+
227+
keys, err = sqs.getEncryptedClientHelloKeys(hello)
228+
if err != nil {
229+
t.Fatalf("getting restored ECH keys: %v", err)
230+
}
231+
if !reflect.DeepEqual(keys, initialKeys) {
232+
t.Fatalf("unexpected restored ECH keys: got %#v, want %#v", keys, initialKeys)
233+
}
234+
}
235+
178236
func TestParseNetworkAddress(t *testing.T) {
179237
for i, tc := range []struct {
180238
input string

0 commit comments

Comments
 (0)