diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 03a611b..04beb68 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -18,7 +18,7 @@ jobs: - name: Set up Golang uses: actions/setup-go@v5 with: - go-version: '1.23' + go-version: '1.26' - name: Install etcd if: runner.os == 'Linux' @@ -44,6 +44,15 @@ jobs: version: latest args: --verbose + - name: Test + run: | + go test -failfast -v -coverprofile profile.cov ./... + + - name: Coveralls + uses: coverallsapp/github-action@v2 + with: + file: profile.cov + - name: Test E2E if: runner.os == 'Linux' run: | diff --git a/checker/consul_leader_checker.go b/checker/consul_leader_checker.go index 1b8f878..250f0f1 100644 --- a/checker/consul_leader_checker.go +++ b/checker/consul_leader_checker.go @@ -23,7 +23,11 @@ func NewConsulLeaderChecker(con *vipconfig.Config) (lc *ConsulLeaderChecker, err url, err := url.Parse(con.Endpoints[0]) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to parse consul endpoint URL %s: %w", con.Endpoints[0], err) + } + + if url.Hostname() == "" { + return nil, fmt.Errorf("invalid consul endpoint URL: hostname is empty in %s", con.Endpoints[0]) } config := &api.Config{ @@ -34,7 +38,7 @@ func NewConsulLeaderChecker(con *vipconfig.Config) (lc *ConsulLeaderChecker, err } if lc.Client, err = api.NewClient(config); err != nil { - return nil, err + return nil, fmt.Errorf("failed to create consul client for endpoint %s: %w", con.Endpoints[0], err) } return lc, nil diff --git a/checker/consul_leader_checker_test.go b/checker/consul_leader_checker_test.go new file mode 100644 index 0000000..91dfb63 --- /dev/null +++ b/checker/consul_leader_checker_test.go @@ -0,0 +1,57 @@ +package checker + +import ( + "strings" + "testing" + + "github.com/cybertec-postgresql/vip-manager/vipconfig" + "go.uber.org/zap" +) + +func newTestConfig(endpoint string) *vipconfig.Config { + return &vipconfig.Config{ + Endpoints: []string{endpoint}, + Logger: zap.NewNop(), + } +} + +// TestNewConsulLeaderChecker_UnparseableURL verifies that a URL containing a +// null byte (rejected by net/url) is wrapped with context. +func TestNewConsulLeaderChecker_UnparseableURL(t *testing.T) { + t.Parallel() + _, err := NewConsulLeaderChecker(newTestConfig("http://invalid\x00host")) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "failed to parse consul endpoint URL") { + t.Errorf("unexpected error message: %v", err) + } +} + +// TestNewConsulLeaderChecker_EmptyHostname verifies that a URL with no host +// component (e.g. a bare path) is rejected with the empty-hostname sentinel. +func TestNewConsulLeaderChecker_EmptyHostname(t *testing.T) { + t.Parallel() + // "localhost" without a scheme parses successfully but Hostname() == "" + _, err := NewConsulLeaderChecker(newTestConfig("localhost")) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "hostname is empty") { + t.Errorf("unexpected error message: %v", err) + } +} + +// TestNewConsulLeaderChecker_ValidURL verifies that a well-formed endpoint +// does not produce a construction error (api.NewClient never fails for valid +// address strings). +func TestNewConsulLeaderChecker_ValidURL(t *testing.T) { + t.Parallel() + lc, err := NewConsulLeaderChecker(newTestConfig("http://127.0.0.1:8500")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if lc == nil { + t.Fatal("expected non-nil checker") + } +} diff --git a/checker/etcd_leader_checker.go b/checker/etcd_leader_checker.go index 9155db9..f4b3c4e 100644 --- a/checker/etcd_leader_checker.go +++ b/checker/etcd_leader_checker.go @@ -23,7 +23,7 @@ type EtcdLeaderChecker struct { func NewEtcdLeaderChecker(conf *vipconfig.Config) (*EtcdLeaderChecker, error) { tlsConfig, err := getTransport(conf) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to create TLS transport for etcd: %w", err) } cfg := clientv3.Config{ Endpoints: conf.Endpoints, @@ -35,7 +35,10 @@ func NewEtcdLeaderChecker(conf *vipconfig.Config) (*EtcdLeaderChecker, error) { Logger: conf.Logger, } c, err := clientv3.New(cfg) - return &EtcdLeaderChecker{conf, c}, err + if err != nil { + return nil, fmt.Errorf("failed to connect to etcd at endpoints %v: %w", conf.Endpoints, err) + } + return &EtcdLeaderChecker{conf, c}, nil } func getTransport(conf *vipconfig.Config) (*tls.Config, error) { @@ -74,7 +77,17 @@ func getTransport(conf *vipconfig.Config) (*tls.Config, error) { func (elc *EtcdLeaderChecker) get(ctx context.Context, out chan<- bool) { resp, err := elc.Get(ctx, elc.TriggerKey) if err != nil { - elc.Logger.Error("Failed to get etcd value:", zap.Error(err)) + elc.Logger.Error("Failed to get etcd value", zap.String("key", elc.TriggerKey), zap.Error(err)) + out <- false + return + } + if resp == nil { + elc.Logger.Error("Received nil response from etcd", zap.String("key", elc.TriggerKey)) + out <- false + return + } + if len(resp.Kvs) == 0 { + elc.Logger.Sugar().Info("No value found for key ", elc.TriggerKey, " - DCS may not have a leader yet") out <- false return } @@ -99,6 +112,7 @@ func (elc *EtcdLeaderChecker) watch(ctx context.Context, out chan<- bool) error continue } if err := watchResp.Err(); err != nil { + elc.Logger.Error("Watch error for key "+elc.TriggerKey+":", zap.Error(err)) elc.get(ctx, out) // RPC failed, try to get the key directly to be on the safe side continue } diff --git a/checker/etcd_leader_checker_test.go b/checker/etcd_leader_checker_test.go new file mode 100644 index 0000000..954e5a6 --- /dev/null +++ b/checker/etcd_leader_checker_test.go @@ -0,0 +1,137 @@ +package checker + +import ( + "path/filepath" + "runtime" + "strings" + "testing" + + "github.com/cybertec-postgresql/vip-manager/vipconfig" + "go.uber.org/zap" +) + +// certsDir returns the absolute path to the shared test certificates. +func certsDir() string { + _, file, _, _ := runtime.Caller(0) + return filepath.Join(filepath.Dir(file), "..", "test", "certs") +} + +func etcdConfig() *vipconfig.Config { + return &vipconfig.Config{ + Endpoints: []string{"http://127.0.0.1:2379"}, + Logger: zap.NewNop(), + } +} + +// --------------------------------------------------------------------------- +// getTransport +// --------------------------------------------------------------------------- + +// TestGetTransport_NoTLS verifies that an empty TLS config is accepted and +// returns a non-nil (but empty) *tls.Config. +func TestGetTransport_NoTLS(t *testing.T) { + t.Parallel() + cfg, err := getTransport(etcdConfig()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cfg == nil { + t.Fatal("expected non-nil tls.Config") + } +} + +// TestGetTransport_MissingCAFile verifies the error when the CA file path does +// not exist. +func TestGetTransport_MissingCAFile(t *testing.T) { + t.Parallel() + conf := etcdConfig() + conf.EtcdCAFile = "/nonexistent/ca.crt" + _, err := getTransport(conf) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "cannot load CA file") { + t.Errorf("unexpected error message: %v", err) + } +} + +// TestGetTransport_MissingCertFiles verifies the error when the client cert or +// key file is missing. +func TestGetTransport_MissingCertFiles(t *testing.T) { + t.Parallel() + conf := etcdConfig() + conf.EtcdCertFile = "/nonexistent/client.crt" + conf.EtcdKeyFile = "/nonexistent/client.key" + _, err := getTransport(conf) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "cannot load client cert or key file") { + t.Errorf("unexpected error message: %v", err) + } +} + +// TestGetTransport_ValidCAFile verifies that a real CA certificate file is +// loaded without error. +func TestGetTransport_ValidCAFile(t *testing.T) { + t.Parallel() + conf := etcdConfig() + conf.EtcdCAFile = filepath.Join(certsDir(), "etcd_server_ca.crt") + cfg, err := getTransport(conf) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cfg.RootCAs == nil { + t.Error("expected RootCAs to be populated") + } +} + +// TestGetTransport_ValidCertAndKey verifies that a real client cert+key pair +// is loaded without error. +func TestGetTransport_ValidCertAndKey(t *testing.T) { + t.Parallel() + conf := etcdConfig() + conf.EtcdCAFile = filepath.Join(certsDir(), "etcd_server_ca.crt") + conf.EtcdCertFile = filepath.Join(certsDir(), "etcd_client.crt") + conf.EtcdKeyFile = filepath.Join(certsDir(), "etcd_client.key") + cfg, err := getTransport(conf) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(cfg.Certificates) == 0 { + t.Error("expected certificates to be populated") + } +} + +// --------------------------------------------------------------------------- +// NewEtcdLeaderChecker +// --------------------------------------------------------------------------- + +// TestNewEtcdLeaderChecker_TLSError verifies that a TLS config error is +// wrapped with "failed to create TLS transport for etcd". +func TestNewEtcdLeaderChecker_TLSError(t *testing.T) { + t.Parallel() + conf := etcdConfig() + conf.EtcdCAFile = "/nonexistent/ca.crt" + _, err := NewEtcdLeaderChecker(conf) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "failed to create TLS transport for etcd") { + t.Errorf("unexpected error message: %v", err) + } +} + +// TestNewEtcdLeaderChecker_ValidConfig verifies that the checker is created +// without error when endpoints and TLS are valid. The etcd client connects +// lazily so no live server is required. +func TestNewEtcdLeaderChecker_ValidConfig(t *testing.T) { + t.Parallel() + checker, err := NewEtcdLeaderChecker(etcdConfig()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if checker == nil { + t.Fatal("expected non-nil checker") + } +} diff --git a/checker/patroni_leader_checker.go b/checker/patroni_leader_checker.go index 2296a08..e34ed79 100644 --- a/checker/patroni_leader_checker.go +++ b/checker/patroni_leader_checker.go @@ -47,12 +47,17 @@ func (c *PatroniLeaderChecker) GetChangeNotificationStream(ctx context.Context, case <-ctx.Done(): return nil case <-time.After(time.Duration(c.Interval) * time.Millisecond): - r, err := c.Get(c.Endpoints[0] + c.TriggerKey) + url := c.Endpoints[0] + c.TriggerKey + r, err := c.Get(url) if err != nil { - c.Logger.Sugar().Error("patroni REST API error:", err) + c.Logger.Sugar().Errorf("patroni REST API error connecting to %s: %v", url, err) + out <- false continue } r.Body.Close() //throw away the body + if r.StatusCode < 200 || r.StatusCode >= 300 { + c.Logger.Sugar().Warnf("patroni REST API returned non-success status code %d for %s (expected %s)", r.StatusCode, url, c.TriggerValue) + } out <- strconv.Itoa(r.StatusCode) == c.TriggerValue } } diff --git a/checker/patroni_leader_checker_test.go b/checker/patroni_leader_checker_test.go new file mode 100644 index 0000000..ab37adc --- /dev/null +++ b/checker/patroni_leader_checker_test.go @@ -0,0 +1,136 @@ +package checker + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/cybertec-postgresql/vip-manager/vipconfig" + "go.uber.org/zap" +) + +func patroniConfig(endpoint, triggerKey, triggerValue string) *vipconfig.Config { + return &vipconfig.Config{ + Endpoints: []string{endpoint}, + TriggerKey: triggerKey, + TriggerValue: triggerValue, + Interval: 1, // 1 ms – fast for unit tests + Logger: zap.NewNop(), + } +} + +// runStream starts GetChangeNotificationStream in a goroutine and returns the +// first value emitted on out, canceling the context afterwards. Fails the test +// if no value arrives within 2 s. +func runStream(t *testing.T, conf *vipconfig.Config) bool { + t.Helper() + checker, err := NewPatroniLeaderChecker(conf) + if err != nil { + t.Fatalf("NewPatroniLeaderChecker: %v", err) + } + + out := make(chan bool, 1) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { _ = checker.GetChangeNotificationStream(ctx, out) }() + + select { + case v := <-out: + cancel() + return v + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for stream value") + return false + } +} + +// --------------------------------------------------------------------------- +// NewPatroniLeaderChecker +// --------------------------------------------------------------------------- + +// TestNewPatroniLeaderChecker_TLSError ensures that a missing cert file causes +// construction to fail (error originates from getTransport). +func TestNewPatroniLeaderChecker_TLSError(t *testing.T) { + t.Parallel() + conf := patroniConfig("http://127.0.0.1:8008", "/leader", "200") + conf.EtcdCertFile = "/nonexistent/client.crt" + conf.EtcdKeyFile = "/nonexistent/client.key" + _, err := NewPatroniLeaderChecker(conf) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "cannot load client cert or key file") { + t.Errorf("unexpected error message: %v", err) + } +} + +// --------------------------------------------------------------------------- +// GetChangeNotificationStream +// --------------------------------------------------------------------------- + +// TestGetChangeNotificationStream_HTTPError verifies that a connection failure +// causes false to be sent on the output channel. +func TestGetChangeNotificationStream_HTTPError(t *testing.T) { + t.Parallel() + // Use a server that we close immediately so all requests get "connection refused". + srv := httptest.NewServer(http.NotFoundHandler()) + srv.Close() + + conf := patroniConfig(srv.URL, "/leader", "200") + result := runStream(t, conf) + if result != false { + t.Errorf("expected false on connection error, got true") + } +} + +// TestGetChangeNotificationStream_StatusMatch verifies that when the server +// returns the expected status code the stream emits true. +func TestGetChangeNotificationStream_StatusMatch(t *testing.T) { + t.Parallel() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) // 200 + })) + defer srv.Close() + + conf := patroniConfig(srv.URL, "/leader", "200") + if !runStream(t, conf) { + t.Error("expected true when status code matches trigger value") + } +} + +// TestGetChangeNotificationStream_StatusNoMatch verifies that a different +// status code causes false to be emitted. +func TestGetChangeNotificationStream_StatusNoMatch(t *testing.T) { + t.Parallel() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) // 503 + })) + defer srv.Close() + + conf := patroniConfig(srv.URL, "/leader", "200") + if runStream(t, conf) { + t.Error("expected false when status code does not match trigger value") + } +} + +// TestGetChangeNotificationStream_NonSuccessMatch verifies that a non-2xx +// status code that happens to equal the trigger value still emits true +// (the warning log does not prevent correct evaluation). +func TestGetChangeNotificationStream_NonSuccessMatch(t *testing.T) { + t.Parallel() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) // 503 + })) + defer srv.Close() + + // Patroni uses 503 to signal "not the leader" – but if an operator + // configures trigger-value=503 they expect true here. + conf := patroniConfig(srv.URL, "/leader", "503") + if !runStream(t, conf) { + t.Error("expected true when non-2xx status code matches trigger value") + } +} diff --git a/ipmanager/hetznerConfigurer.go b/ipmanager/hetznerConfigurer.go index 152d4c5..a89c479 100644 --- a/ipmanager/hetznerConfigurer.go +++ b/ipmanager/hetznerConfigurer.go @@ -4,6 +4,7 @@ import ( "bufio" "encoding/json" "errors" + "fmt" "net" "os" "os/exec" @@ -40,17 +41,16 @@ func newHetznerConfigurer(config *IPConfiguration, verbose bool) (*HetznerConfig * In order to tell the Hetzner API to route the failover-ip to * this machine, we must attach our own IP address to the API request. */ -func getOutboundIP() net.IP { +func getOutboundIP() (net.IP, error) { conn, err := net.Dial("udp", "8.8.8.8:80") if err != nil || conn == nil { - log.Error("error dialing 8.8.8.8 to retrieve preferred outbound IP", err) - return nil + return nil, fmt.Errorf("error dialing 8.8.8.8 to retrieve preferred outbound IP: %w", err) } defer conn.Close() localAddr := conn.LocalAddr().(*net.UDPAddr) - return localAddr.IP + return localAddr.IP, nil } func (c *HetznerConfigurer) curlQueryFailover(post bool) (string, error) { @@ -98,10 +98,10 @@ func (c *HetznerConfigurer) curlQueryFailover(post bool) (string, error) { */ var cmd *exec.Cmd if post { - myOwnIP := getOutboundIP() - if myOwnIP == nil { - log.Error("Error determining this machine's IP address.") - return "", errors.New("error determining this machine's IP address") + myOwnIP, err := getOutboundIP() + if err != nil { + log.Error("Error determining this machine's IP address.", err) + return "", fmt.Errorf("error determining this machine's IP address: %w", err) } log.Infof("my_own_ip: %s\n", myOwnIP.String()) @@ -224,7 +224,14 @@ func (c *HetznerConfigurer) queryAddress() bool { c.cachedState = unknown } - if currentFailoverDestinationIP.Equal(getOutboundIP()) { + myOwnIP, err := getOutboundIP() + if err != nil { + log.Error("Error determining this machine's IP address.", err) + c.cachedState = unknown + return false + } + + if currentFailoverDestinationIP.Equal(myOwnIP) { //We "are" the current failover destination. c.cachedState = configured return true @@ -262,7 +269,14 @@ func (c *HetznerConfigurer) runAddressConfiguration() bool { c.lastAPICheck = time.Now() - if currentFailoverDestinationIP.Equal(getOutboundIP()) { + myOwnIP, err := getOutboundIP() + if err != nil { + log.Error("Error determining this machine's IP address.", err) + c.cachedState = unknown + return false + } + + if currentFailoverDestinationIP.Equal(myOwnIP) { //We "are" the current failover destination. log.Info("Failover was successfully executed!") c.cachedState = configured @@ -271,7 +285,7 @@ func (c *HetznerConfigurer) runAddressConfiguration() bool { log.Infof("The failover command was issued, but the current Failover destination (%s) is different from what it should be (%s).", currentFailoverDestinationIP.String(), - getOutboundIP().String()) + myOwnIP.String()) //Something must have gone wrong while trying to switch IP's... c.cachedState = unknown return false diff --git a/ipmanager/ip_manager.go b/ipmanager/ip_manager.go index 97a4143..ab786d6 100644 --- a/ipmanager/ip_manager.go +++ b/ipmanager/ip_manager.go @@ -2,6 +2,7 @@ package ipmanager import ( "context" + "fmt" "net" "net/netip" "sync/atomic" @@ -18,7 +19,7 @@ type ipConfigurer interface { getCIDR() string } -var log *zap.SugaredLogger = zap.L().Sugar() +var log *zap.SugaredLogger // IPManager implements the main functionality of the VIP manager type IPManager struct { @@ -40,22 +41,28 @@ func getMask(vip netip.Addr, mask int) net.IPMask { return net.CIDRMask(mask, 128) //IPv6 } -func getNetIface(iface string) *net.Interface { +func getNetIface(iface string) (*net.Interface, error) { netIface, err := net.InterfaceByName(iface) if err != nil { - log.Fatalf("Obtaining the interface raised an error: %s", err) + return nil, fmt.Errorf("failed to get interface %s: %w", iface, err) } - return netIface + if netIface.Flags&net.FlagUp == 0 { + return nil, fmt.Errorf("interface %s is not up", iface) + } + return netIface, nil } // NewIPManager returns a new instance of IPManager func NewIPManager(conf *vipconfig.Config, states <-chan bool) (m *IPManager, err error) { vip, err := netip.ParseAddr(conf.IP) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to parse VIP address: %w", err) } vipMask := getMask(vip, conf.Mask) - netIface := getNetIface(conf.Iface) + netIface, err := getNetIface(conf.Iface) + if err != nil { + return nil, err + } ipConf := &IPConfiguration{ VIP: vip, Netmask: vipMask, diff --git a/ipmanager/ip_manager_test.go b/ipmanager/ip_manager_test.go new file mode 100644 index 0000000..21ebbdb --- /dev/null +++ b/ipmanager/ip_manager_test.go @@ -0,0 +1,68 @@ +package ipmanager + +import ( + "strings" + "testing" + + "github.com/cybertec-postgresql/vip-manager/vipconfig" + "go.uber.org/zap" +) + +func minimalConfig(vip, iface string) *vipconfig.Config { + return &vipconfig.Config{ + IP: vip, + Mask: 24, + Iface: iface, + HostingType: "basic", + Logger: zap.NewNop(), + } +} + +// --------------------------------------------------------------------------- +// getNetIface +// --------------------------------------------------------------------------- + +// TestGetNetIface_Nonexistent verifies that requesting an interface that does +// not exist returns an error containing "failed to get interface". +func TestGetNetIface_Nonexistent(t *testing.T) { + t.Parallel() + _, err := getNetIface("definitely_nonexistent_interface_999") + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "failed to get interface") { + t.Errorf("unexpected error message: %v", err) + } +} + +// --------------------------------------------------------------------------- +// NewIPManager +// --------------------------------------------------------------------------- + +// TestNewIPManager_InvalidVIP verifies that a non-IP string is rejected with +// "failed to parse VIP address". +func TestNewIPManager_InvalidVIP(t *testing.T) { + t.Parallel() + states := make(chan bool) + _, err := NewIPManager(minimalConfig("not-an-ip-address", "lo"), states) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "failed to parse VIP address") { + t.Errorf("unexpected error message: %v", err) + } +} + +// TestNewIPManager_InvalidInterface verifies that a valid VIP combined with a +// nonexistent interface name returns an error from getNetIface. +func TestNewIPManager_InvalidInterface(t *testing.T) { + t.Parallel() + states := make(chan bool) + _, err := NewIPManager(minimalConfig("10.0.0.1", "definitely_nonexistent_interface_999"), states) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "failed to get interface") { + t.Errorf("unexpected error message: %v", err) + } +}