|
15 | 15 | package caddy |
16 | 16 |
|
17 | 17 | import ( |
| 18 | + "crypto/tls" |
18 | 19 | "reflect" |
19 | 20 | "testing" |
20 | 21 |
|
@@ -175,6 +176,63 @@ func TestJoinNetworkAddress(t *testing.T) { |
175 | 176 | } |
176 | 177 | } |
177 | 178 |
|
| 179 | +func TestSharedQUICStateGetEncryptedClientHelloKeys(t *testing.T) { |
| 180 | + hello := &tls.ClientHelloInfo{ServerName: "example.com"} |
| 181 | + initialKeys := []tls.EncryptedClientHelloKey{{Config: []byte("initial"), PrivateKey: []byte("initial-key")}} |
| 182 | + updatedKeys := []tls.EncryptedClientHelloKey{{Config: []byte("updated"), PrivateKey: []byte("updated-key")}} |
| 183 | + |
| 184 | + initialConfig := &tls.Config{ |
| 185 | + GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { |
| 186 | + return nil, nil |
| 187 | + }, |
| 188 | + GetEncryptedClientHelloKeys: func(*tls.ClientHelloInfo) ([]tls.EncryptedClientHelloKey, error) { |
| 189 | + return initialKeys, nil |
| 190 | + }, |
| 191 | + } |
| 192 | + |
| 193 | + sqs := newSharedQUICState(initialConfig) |
| 194 | + |
| 195 | + keys, err := sqs.getEncryptedClientHelloKeys(hello) |
| 196 | + if err != nil { |
| 197 | + t.Fatalf("getting initial ECH keys: %v", err) |
| 198 | + } |
| 199 | + if !reflect.DeepEqual(keys, initialKeys) { |
| 200 | + t.Fatalf("unexpected initial ECH keys: got %#v, want %#v", keys, initialKeys) |
| 201 | + } |
| 202 | + |
| 203 | + updatedConfig := &tls.Config{ |
| 204 | + GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { |
| 205 | + return nil, nil |
| 206 | + }, |
| 207 | + GetEncryptedClientHelloKeys: func(*tls.ClientHelloInfo) ([]tls.EncryptedClientHelloKey, error) { |
| 208 | + return updatedKeys, nil |
| 209 | + }, |
| 210 | + } |
| 211 | + |
| 212 | + _, cancel := sqs.addState(updatedConfig) |
| 213 | + sqs.rmu.Lock() |
| 214 | + sqs.activeTlsConf = updatedConfig |
| 215 | + sqs.rmu.Unlock() |
| 216 | + |
| 217 | + keys, err = sqs.getEncryptedClientHelloKeys(hello) |
| 218 | + if err != nil { |
| 219 | + t.Fatalf("getting updated ECH keys: %v", err) |
| 220 | + } |
| 221 | + if !reflect.DeepEqual(keys, updatedKeys) { |
| 222 | + t.Fatalf("unexpected updated ECH keys: got %#v, want %#v", keys, updatedKeys) |
| 223 | + } |
| 224 | + |
| 225 | + cancel(nil) |
| 226 | + |
| 227 | + keys, err = sqs.getEncryptedClientHelloKeys(hello) |
| 228 | + if err != nil { |
| 229 | + t.Fatalf("getting restored ECH keys: %v", err) |
| 230 | + } |
| 231 | + if !reflect.DeepEqual(keys, initialKeys) { |
| 232 | + t.Fatalf("unexpected restored ECH keys: got %#v, want %#v", keys, initialKeys) |
| 233 | + } |
| 234 | +} |
| 235 | + |
178 | 236 | func TestParseNetworkAddress(t *testing.T) { |
179 | 237 | for i, tc := range []struct { |
180 | 238 | input string |
|
0 commit comments