Skip to content

Commit b7e7ab8

Browse files
authored
Merge pull request #65 from DCsunset/gotify-root-ca
Add environment variable to set root CA for TLS verification
2 parents ae976e2 + 9a0ffa5 commit b7e7ab8

3 files changed

Lines changed: 195 additions & 1 deletion

File tree

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,11 @@ If needed, you can disable SSL handcheck validation using an environment variabl
205205
export GOTIFY_SKIP_VERIFY_TLS=True
206206
```
207207

208+
For better security with self-signed certificate, you can also set custom root CA or pin the server cert for TLS verification:
209+
```
210+
export SSL_CERT_FILE=/path/to/cert.pem
211+
```
212+
208213

209214
### Dockerfile
210215
The Dockerfile contains the steps necessary to build a new version of the CLI and then run it in

utils/createhttpclient.go

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,35 @@ package utils
22

33
import (
44
"crypto/tls"
5+
"crypto/x509"
56
"net/http"
67
"os"
78
"strings"
89
)
910

1011
func CreateHTTPClient() *http.Client {
1112
skipVerify := strings.ToLower(os.Getenv("GOTIFY_SKIP_VERIFY_TLS")) == "true"
13+
certFile := os.Getenv("SSL_CERT_FILE")
14+
if skipVerify && certFile != "" {
15+
Exit1With("GOTIFY_SKIP_VERIFY_TLS and SSL_CERT_FILE shouldn't be set at the same time")
16+
}
17+
1218
customTransport := http.DefaultTransport.(*http.Transport).Clone()
13-
customTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: skipVerify}
19+
rootCAs := customTransport.TLSClientConfig.RootCAs
20+
if certFile != "" {
21+
cert, err := os.ReadFile(certFile)
22+
if err != nil {
23+
Exit1With("Failed to read cert:", err)
24+
}
25+
rootCAs = x509.NewCertPool()
26+
ok := rootCAs.AppendCertsFromPEM(cert)
27+
if !ok {
28+
Exit1With("Failed to parse cert", certFile)
29+
}
30+
}
31+
customTransport.TLSClientConfig = &tls.Config{
32+
InsecureSkipVerify: skipVerify,
33+
RootCAs: rootCAs,
34+
}
1435
return &http.Client{Transport: customTransport}
1536
}

utils/createhttpclient_test.go

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
package utils
2+
3+
import (
4+
"context"
5+
"crypto/ed25519"
6+
"crypto/rand"
7+
"crypto/tls"
8+
"crypto/x509"
9+
"encoding/pem"
10+
"errors"
11+
"math/big"
12+
"net"
13+
"net/http"
14+
"os"
15+
"sync"
16+
"sync/atomic"
17+
"testing"
18+
"time"
19+
)
20+
21+
func newCA(t *testing.T) ([]byte, func(domain string) (ed25519.PublicKey, ed25519.PrivateKey)) {
22+
caPubKey, caPrivKey, err := ed25519.GenerateKey(rand.Reader)
23+
if err != nil {
24+
t.Fatalf("failed to generate key: %v", err)
25+
}
26+
cert, err := x509.CreateCertificate(rand.Reader, &x509.Certificate{
27+
SerialNumber: big.NewInt(1),
28+
NotAfter: time.Now().Add(time.Hour),
29+
IsCA: true,
30+
BasicConstraintsValid: true,
31+
}, &x509.Certificate{}, caPubKey, caPrivKey)
32+
33+
if err != nil {
34+
t.Fatalf("failed to create certificate: %v", err)
35+
}
36+
37+
certParsed, err := x509.ParseCertificate(cert)
38+
if err != nil {
39+
t.Fatalf("failed to parse certificate: %v", err)
40+
}
41+
42+
return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert}), func(domain string) (ed25519.PublicKey, ed25519.PrivateKey) {
43+
pubKey, privKey, err := ed25519.GenerateKey(rand.Reader)
44+
if err != nil {
45+
t.Fatalf("failed to generate key: %v", err)
46+
}
47+
48+
cert, err := x509.CreateCertificate(rand.Reader, &x509.Certificate{
49+
DNSNames: []string{domain},
50+
SerialNumber: big.NewInt(2),
51+
NotAfter: time.Now().Add(time.Hour),
52+
}, certParsed, pubKey, caPrivKey)
53+
54+
if err != nil {
55+
t.Fatalf("failed to create certificate: %v", err)
56+
}
57+
58+
privPEM, err := x509.MarshalPKCS8PrivateKey(privKey)
59+
if err != nil {
60+
t.Fatalf("failed to marshal private key: %v", err)
61+
}
62+
63+
return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert}),
64+
pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privPEM})
65+
}
66+
}
67+
68+
func TestCreateHTTPClient(t *testing.T) {
69+
caPEM, signer := newCA(t)
70+
wrongCAPEM, wrongSigner := newCA(t)
71+
72+
certPEM, certPriv := signer("gotify.local")
73+
wrongDomainPEM, wrongDomainPriv := signer("gotify.invalid")
74+
wrongCAPEM, wrongCAPriv := wrongSigner("gotify.local")
75+
76+
testTrust := func(trustCert []byte, serverPEM []byte, serverKey []byte) bool {
77+
serverSide, clientSide := net.Pipe()
78+
79+
serverCert, err := tls.X509KeyPair(serverPEM, serverKey)
80+
if err != nil {
81+
panic(err)
82+
}
83+
84+
tlsServer := tls.Server(serverSide, &tls.Config{
85+
Certificates: []tls.Certificate{
86+
serverCert,
87+
},
88+
})
89+
90+
var certFile *os.File = nil
91+
if trustCert != nil {
92+
var err error
93+
certFile, err = os.CreateTemp("", "GotifyTrustCert")
94+
if err != nil {
95+
t.Fatalf("Failed to create temp file: %v", err)
96+
}
97+
certFile.Write(trustCert)
98+
certFile.Close()
99+
os.Setenv("SSL_CERT_FILE", certFile.Name())
100+
}
101+
102+
client := CreateHTTPClient()
103+
client.Transport.(*http.Transport).DialContext = func(_ context.Context, network, addr string) (net.Conn, error) {
104+
return clientSide, nil
105+
}
106+
107+
os.Unsetenv("SSL_CERT_FILE")
108+
if certFile != nil {
109+
os.Remove(certFile.Name())
110+
}
111+
112+
var failed uint32 = 0
113+
var unexpected error
114+
115+
wg := sync.WaitGroup{}
116+
wg.Add(2)
117+
118+
go func() {
119+
defer serverSide.Close()
120+
defer wg.Done()
121+
122+
if err := tlsServer.Handshake(); err == nil {
123+
tlsServer.Write([]byte("HTTP/1.1 200 OK\r\n\r\n"))
124+
}
125+
tlsServer.Close()
126+
}()
127+
128+
go func() {
129+
defer clientSide.Close()
130+
defer wg.Done()
131+
132+
if _, err := client.Get("https://gotify.local"); err != nil {
133+
if _, ok := errors.Unwrap(err).(*tls.CertificateVerificationError); ok {
134+
atomic.StoreUint32(&failed, 1)
135+
} else {
136+
unexpected = err
137+
}
138+
}
139+
}()
140+
141+
wg.Wait()
142+
if unexpected != nil {
143+
t.Fatal(unexpected)
144+
}
145+
146+
return atomic.LoadUint32(&failed) == 0
147+
}
148+
149+
if !testTrust(certPEM, certPEM, certPriv) {
150+
t.Fatal("failed to trust valid server cert")
151+
}
152+
153+
if !testTrust(caPEM, certPEM, certPriv) {
154+
t.Fatal("failed to trust valid CA")
155+
}
156+
157+
if testTrust(caPEM, wrongCAPEM, wrongCAPriv) {
158+
t.Fatal("trusted invalid cert")
159+
}
160+
161+
if testTrust(caPEM, wrongDomainPEM, wrongDomainPriv) {
162+
t.Fatal("trusted cert with invalid domain")
163+
}
164+
165+
if testTrust(nil, certPEM, certPriv) {
166+
t.Fatal("shouldn't trust server cert")
167+
}
168+
}

0 commit comments

Comments
 (0)