Skip to content

Commit a938682

Browse files
committed
security: allow skip-ca without certificates to continue building tls config (#951)
1 parent 9aba8c4 commit a938682

5 files changed

Lines changed: 33 additions & 141 deletions

File tree

lib/util/security/cert.go

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -234,14 +234,7 @@ func (ci *CertInfo) buildClientConfig(lg *zap.Logger) (*tls.Config, error) {
234234
lg.Warn("specified auto-certs in a client tls config, ignored")
235235
}
236236

237-
if !cfg.HasCA() {
238-
if cfg.SkipCA {
239-
// still enable TLS without verify server certs
240-
return &tls.Config{
241-
InsecureSkipVerify: true,
242-
MinVersion: GetMinTLSVer(cfg.MinTLSVersion, lg),
243-
}, nil
244-
}
237+
if !cfg.HasCA() && !cfg.SkipCA {
245238
lg.Debug("no CA to verify server connections, disable TLS")
246239
return nil, nil
247240
}
@@ -251,30 +244,32 @@ func (ci *CertInfo) buildClientConfig(lg *zap.Logger) (*tls.Config, error) {
251244
GetCertificate: ci.getCert,
252245
GetClientCertificate: ci.getClientCert,
253246
InsecureSkipVerify: true,
254-
VerifyPeerCertificate: func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
255-
return ci.verifyCA(rawCerts)
256-
},
257247
}
258248

259-
caPEM, err := os.ReadFile(cfg.CA)
260-
if err != nil {
261-
return nil, err
262-
}
263-
certPool := x509.NewCertPool()
264-
if !certPool.AppendCertsFromPEM(caPEM) {
265-
return nil, errors.New("failed to append ca certs")
249+
if cfg.HasCA() {
250+
tcfg.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
251+
return ci.verifyCA(rawCerts)
252+
}
253+
caPEM, err := os.ReadFile(cfg.CA)
254+
if err != nil {
255+
return nil, err
256+
}
257+
certPool := x509.NewCertPool()
258+
if !certPool.AppendCertsFromPEM(caPEM) {
259+
return nil, errors.New("failed to append ca certs")
260+
}
261+
ci.ca.Store(certPool)
262+
tcfg.RootCAs = certPool
266263
}
267-
ci.ca.Store(certPool)
268-
tcfg.RootCAs = certPool
269264

270-
if !cfg.HasCert() {
265+
if cfg.Cert == "" || cfg.Key == "" {
271266
lg.Debug("no certificates, server may reject the connection")
272267
return tcfg, nil
273268
}
274269

275270
cert, err := tls.LoadX509KeyPair(cfg.Cert, cfg.Key)
276271
if err != nil {
277-
return nil, errors.WithStack(err)
272+
return nil, err
278273
}
279274
ci.cert.Store(&cert)
280275

lib/util/security/cert_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ func TestCertServer(t *testing.T) {
187187
require.Nil(t, c.RootCAs)
188188
require.Nil(t, ci.cert.Load())
189189
require.Equal(t, tls.VersionTLS12, int(c.MinVersion))
190+
require.NotNil(t, c.GetClientCertificate, "skip-ca should set GetClientCertificate")
190191
},
191192
},
192193
{
@@ -336,6 +337,7 @@ func TestSetConfig(t *testing.T) {
336337
require.NoError(t, err)
337338
require.NotNil(t, tcfg)
338339
require.True(t, tcfg.InsecureSkipVerify)
340+
require.NotNil(t, tcfg.GetClientCertificate, "skip-ca should set GetClientCertificate")
339341

340342
cfg = config.TLSConfig{
341343
SkipCA: false,

lib/util/security/tls.go

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -204,28 +204,24 @@ func CreateTLSConfigForTest() (serverTLSConf *tls.Config, clientTLSConf *tls.Con
204204

205205
func BuildClientTLSConfig(logger *zap.Logger, cfg config.TLSConfig) (*tls.Config, error) {
206206
logger = logger.With(zap.String("tls", "client"))
207-
if !cfg.HasCA() {
208-
if cfg.SkipCA {
209-
// still enable TLS without verify server certs
210-
return &tls.Config{
211-
InsecureSkipVerify: true,
212-
MinVersion: tls.VersionTLS11,
213-
}, nil
214-
}
207+
if !cfg.HasCA() && !cfg.SkipCA {
215208
logger.Info("no CA to verify server connections, disable TLS")
216209
return nil, nil
217210
}
218211

219212
tcfg := &tls.Config{
220213
MinVersion: tls.VersionTLS11,
214+
InsecureSkipVerify: cfg.SkipCA,
221215
}
222-
tcfg.RootCAs = x509.NewCertPool()
223-
certBytes, err := os.ReadFile(cfg.CA)
224-
if err != nil {
225-
return nil, errors.Errorf("failed to read CA: %w", err)
226-
}
227-
if !tcfg.RootCAs.AppendCertsFromPEM(certBytes) {
228-
return nil, errors.Errorf("failed to append CA")
216+
if cfg.HasCA() {
217+
tcfg.RootCAs = x509.NewCertPool()
218+
certBytes, err := os.ReadFile(cfg.CA)
219+
if err != nil {
220+
return nil, errors.Errorf("failed to read CA: %w", err)
221+
}
222+
if !tcfg.RootCAs.AppendCertsFromPEM(certBytes) {
223+
return nil, errors.Errorf("failed to append CA")
224+
}
229225
}
230226

231227
if !cfg.HasCert() {

pkg/proxy/net/packetio.go

Lines changed: 3 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ import (
2828
"crypto/tls"
2929
"io"
3030
"net"
31-
"sync"
3231
"time"
3332

3433
"github.com/pingcap/tiproxy/lib/config"
@@ -43,19 +42,6 @@ var (
4342
ErrInvalidSequence = errors.New("invalid sequence")
4443
)
4544

46-
var (
47-
readerPool = sync.Pool{
48-
New: func() any {
49-
return bufio.NewReaderSize(nil, DefaultConnBufferSize)
50-
},
51-
}
52-
writerPool = sync.Pool{
53-
New: func() any {
54-
return bufio.NewWriterSize(nil, DefaultConnBufferSize)
55-
},
56-
}
57-
)
58-
5945
const (
6046
DefaultConnBufferSize = 32 * 1024
6147
)
@@ -100,27 +86,16 @@ type basicReadWriter struct {
10086
inBytes uint64
10187
outBytes uint64
10288
sequence uint8
103-
pooled bool
10489
}
10590

10691
func newBasicReadWriter(conn net.Conn, bufferSize int) *basicReadWriter {
10792
if bufferSize == 0 {
10893
bufferSize = DefaultConnBufferSize
10994
}
110-
brw := &basicReadWriter{
111-
Conn: conn,
112-
}
113-
if bufferSize == DefaultConnBufferSize {
114-
r := readerPool.Get().(*bufio.Reader)
115-
r.Reset(conn)
116-
w := writerPool.Get().(*bufio.Writer)
117-
w.Reset(conn)
118-
brw.ReadWriter = bufio.NewReadWriter(r, w)
119-
brw.pooled = true
120-
} else {
121-
brw.ReadWriter = bufio.NewReadWriter(bufio.NewReaderSize(conn, bufferSize), bufio.NewWriterSize(conn, bufferSize))
95+
return &basicReadWriter{
96+
Conn: conn,
97+
ReadWriter: bufio.NewReadWriter(bufio.NewReaderSize(conn, bufferSize), bufio.NewWriterSize(conn, bufferSize)),
12298
}
123-
return brw
12499
}
125100

126101
func (brw *basicReadWriter) Read(b []byte) (n int, err error) {
@@ -178,16 +153,6 @@ func (brw *basicReadWriter) ResetSequence() {
178153
brw.sequence = 0
179154
}
180155

181-
func (brw *basicReadWriter) Free() {
182-
if brw.pooled {
183-
brw.pooled = false
184-
brw.ReadWriter.Reader.Reset(nil)
185-
brw.ReadWriter.Writer.Reset(nil)
186-
readerPool.Put(brw.ReadWriter.Reader)
187-
writerPool.Put(brw.ReadWriter.Writer)
188-
}
189-
}
190-
191156
func (brw *basicReadWriter) TLSConnectionState() tls.ConnectionState {
192157
return tls.ConnectionState{}
193158
}
@@ -523,19 +488,6 @@ func (p *packetIO) GracefulClose() error {
523488
return nil
524489
}
525490

526-
func freeBasicReadWriter(rw packetReadWriter) {
527-
switch v := rw.(type) {
528-
case *basicReadWriter:
529-
v.Free()
530-
case *tlsReadWriter:
531-
freeBasicReadWriter(v.packetReadWriter)
532-
case *compressedReadWriter:
533-
freeBasicReadWriter(v.packetReadWriter)
534-
case *proxyReadWriter:
535-
freeBasicReadWriter(v.packetReadWriter)
536-
}
537-
}
538-
539491
func (p *packetIO) Close() error {
540492
var errs []error
541493
/*
@@ -544,7 +496,6 @@ func (p *packetIO) Close() error {
544496
errs = append(errs, err)
545497
}
546498
*/
547-
freeBasicReadWriter(p.readWriter)
548499
if err := p.readWriter.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
549500
errs = append(errs, errors.WithStack(err))
550501
}

pkg/proxy/net/packetio_test.go

Lines changed: 0 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -719,55 +719,3 @@ func runForwardBenchmark(b *testing.B, f func(packetIO1, packetIO2 *packetIO)) {
719719
_ = packetIO2.Close()
720720
wg.Wait()
721721
}
722-
723-
func TestPacketIOPooling(t *testing.T) {
724-
testTCPConn(t,
725-
func(t *testing.T, cli *packetIO) {
726-
brw, ok := cli.readWriter.(*basicReadWriter)
727-
require.True(t, ok)
728-
require.True(t, brw.pooled, "pooled flag should be true for default buffer size")
729-
730-
require.NoError(t, cli.WritePacket([]byte("pooltest"), true))
731-
},
732-
func(t *testing.T, srv *packetIO) {
733-
brw, ok := srv.readWriter.(*basicReadWriter)
734-
require.True(t, ok)
735-
require.True(t, brw.pooled, "pooled flag should be true for default buffer size")
736-
737-
data, err := srv.ReadPacket()
738-
require.NoError(t, err)
739-
require.Equal(t, []byte("pooltest"), data)
740-
},
741-
1,
742-
)
743-
744-
lg, _ := logger.CreateLoggerForTest(t)
745-
cli, srv := net.Pipe()
746-
cliIO := NewPacketIO(cli, lg, DefaultConnBufferSize*2) // non-default
747-
srvIO := NewPacketIO(srv, lg, DefaultConnBufferSize*2)
748-
brw, ok := cliIO.readWriter.(*basicReadWriter)
749-
require.True(t, ok)
750-
require.False(t, brw.pooled, "pooled flag should be false for non-default buffer size")
751-
_ = cliIO.Close()
752-
_ = srvIO.Close()
753-
754-
testTCPConn(t,
755-
func(t *testing.T, cli *packetIO) {
756-
require.NoError(t, cli.Close())
757-
require.NoError(t, cli.Close())
758-
},
759-
func(t *testing.T, srv *packetIO) {
760-
require.NoError(t, srv.Close())
761-
require.NoError(t, srv.Close())
762-
},
763-
1,
764-
)
765-
766-
for i := 0; i < 100; i++ {
767-
c1, c2 := net.Pipe()
768-
p1 := NewPacketIO(c1, lg, DefaultConnBufferSize)
769-
p2 := NewPacketIO(c2, lg, DefaultConnBufferSize)
770-
_ = p1.Close()
771-
_ = p2.Close()
772-
}
773-
}

0 commit comments

Comments
 (0)