diff --git a/cmd/paude-proxy/main.go b/cmd/paude-proxy/main.go index f693a56..1803fdc 100644 --- a/cmd/paude-proxy/main.go +++ b/cmd/paude-proxy/main.go @@ -87,6 +87,9 @@ func main() { // Credential store and token vendor credStore, tokenVendor := buildCredentialStore(domainFilter) + // Start background hostname re-resolution (no-op if no hostnames configured) + clientFilter.StartResolving() + // Create and start proxy srv := proxy.New(proxy.Config{ ListenAddr: listenAddr, @@ -113,6 +116,7 @@ func main() { <-done log.Println("Shutting down...") + clientFilter.Stop() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 185e23c..edfa477 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -105,20 +105,31 @@ func (bl *BlockedLogger) Close() error { return bl.file.Close() } -// ClientFilter validates client source IPs against an allowlist of IPs and CIDRs. -// A nil or empty ClientFilter allows all clients. +// ClientFilter validates client source IPs against an allowlist of IPs, CIDRs, +// and DNS hostnames. Hostnames are resolved to IPs at startup and periodically +// re-resolved in the background (every 30s) to handle dynamic IP assignments +// (e.g., Kubernetes pods restarting). A nil or empty ClientFilter allows all clients. type ClientFilter struct { - ips []net.IP - nets []*net.IPNet + ips []net.IP + nets []*net.IPNet + hostnames []string + resolved map[string][]net.IP // hostname -> resolved IPs (protected by mu) + mu sync.RWMutex + stopCh chan struct{} + stopOnce sync.Once } -// NewClientFilter parses a comma-separated list of IPs and CIDRs. -// Returns nil if the input is empty (allow all). +// NewClientFilter parses a comma-separated list of IPs, CIDRs, and DNS hostnames. +// Returns nil if the input is empty (allow all). Hostnames are resolved immediately; +// resolution failures are logged as warnings (the hostname may become resolvable later). func NewClientFilter(s string) (*ClientFilter, error) { if s == "" { return nil, nil } - cf := &ClientFilter{} + cf := &ClientFilter{ + stopCh: make(chan struct{}), + resolved: make(map[string][]net.IP), + } for _, part := range strings.Split(s, ",") { part = strings.TrimSpace(part) if part == "" { @@ -130,18 +141,98 @@ func NewClientFilter(s string) (*ClientFilter, error) { return nil, fmt.Errorf("invalid CIDR %q: %w", part, err) } cf.nets = append(cf.nets, ipNet) - } else { - ip := net.ParseIP(part) - if ip == nil { - return nil, fmt.Errorf("invalid IP %q", part) - } + } else if ip := net.ParseIP(part); ip != nil { cf.ips = append(cf.ips, ip) + } else { + cf.hostnames = append(cf.hostnames, part) } } + + if len(cf.hostnames) > 0 { + cf.resolveHostnames(true) + } + return cf, nil } -// IsAllowed returns true if the given IP is in the allowlist. +// resolveHostnames resolves all configured hostnames and updates the resolved IP map. +// When initialResolve is true, all results are logged. Otherwise, only changes are logged. +func (cf *ClientFilter) resolveHostnames(initialResolve bool) { + newResolved := make(map[string][]net.IP, len(cf.hostnames)) + for _, hostname := range cf.hostnames { + addrs, err := net.LookupHost(hostname) + if err != nil { + log.Printf("WARNING: failed to resolve allowed client hostname %q: %v", hostname, err) + continue + } + var ips []net.IP + for _, addr := range addrs { + if ip := net.ParseIP(addr); ip != nil { + ips = append(ips, ip) + } + } + newResolved[hostname] = ips + } + + cf.mu.Lock() + old := cf.resolved + cf.resolved = newResolved + cf.mu.Unlock() + + for _, hostname := range cf.hostnames { + newIPs := newResolved[hostname] + oldIPs := old[hostname] + if initialResolve || !ipsEqual(newIPs, oldIPs) { + ipStrs := make([]string, len(newIPs)) + for i, ip := range newIPs { + ipStrs[i] = ip.String() + } + log.Printf("Resolved allowed client hostname %q -> %s", hostname, strings.Join(ipStrs, ", ")) + } + } +} + +// ipsEqual returns true if two IP slices contain the same IPs in the same order. +func ipsEqual(a, b []net.IP) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if !a[i].Equal(b[i]) { + return false + } + } + return true +} + +// StartResolving starts a background goroutine that re-resolves all hostname +// entries every 30 seconds to handle pods restarting with new IPs. +func (cf *ClientFilter) StartResolving() { + if cf == nil || len(cf.hostnames) == 0 { + return + } + go func() { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + for { + select { + case <-ticker.C: + cf.resolveHostnames(false) + case <-cf.stopCh: + return + } + } + }() +} + +// Stop stops the background hostname re-resolution goroutine. Safe to call multiple times. +func (cf *ClientFilter) Stop() { + if cf == nil || len(cf.hostnames) == 0 { + return + } + cf.stopOnce.Do(func() { close(cf.stopCh) }) +} + func (cf *ClientFilter) IsAllowed(ip net.IP) bool { if cf == nil { return true @@ -156,10 +247,19 @@ func (cf *ClientFilter) IsAllowed(ip net.IP) bool { return true } } + cf.mu.RLock() + resolved := cf.resolved + cf.mu.RUnlock() + for _, ips := range resolved { + for _, allowed := range ips { + if allowed.Equal(ip) { + return true + } + } + } return false } -// String returns a human-readable representation of the filter. func (cf *ClientFilter) String() string { if cf == nil { return "disabled (all clients allowed)" @@ -171,6 +271,22 @@ func (cf *ClientFilter) String() string { for _, ipNet := range cf.nets { parts = append(parts, ipNet.String()) } + if len(cf.hostnames) > 0 { + cf.mu.RLock() + resolved := cf.resolved + cf.mu.RUnlock() + for _, hostname := range cf.hostnames { + if ips, ok := resolved[hostname]; ok && len(ips) > 0 { + ipStrs := make([]string, len(ips)) + for i, ip := range ips { + ipStrs[i] = ip.String() + } + parts = append(parts, fmt.Sprintf("%s (resolved: %s)", hostname, strings.Join(ipStrs, ", "))) + } else { + parts = append(parts, fmt.Sprintf("%s (unresolved)", hostname)) + } + } + } return strings.Join(parts, ", ") } diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index f27f92f..d018c7e 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -157,7 +157,7 @@ func TestNewClientFilter(t *testing.T) { {"CIDR", "10.0.0.0/24", false, false}, {"mixed", "10.0.0.1,172.16.0.0/12", false, false}, {"with spaces", " 10.0.0.1 , 10.0.0.2 ", false, false}, - {"invalid IP", "notanip", false, true}, + {"hostname", "notanip", false, false}, // treated as DNS hostname, not invalid {"invalid CIDR", "10.0.0.0/99", false, true}, } for _, tt := range tests { @@ -215,3 +215,105 @@ func TestClientFilter_NilAllowsAll(t *testing.T) { t.Error("nil ClientFilter should allow all IPs") } } + +func TestNewClientFilter_Hostnames(t *testing.T) { + cf, err := NewClientFilter("localhost") + if err != nil { + t.Fatalf("NewClientFilter(\"localhost\") error: %v", err) + } + if cf == nil { + t.Fatal("expected non-nil ClientFilter") + } + if len(cf.hostnames) != 1 || cf.hostnames[0] != "localhost" { + t.Errorf("expected hostnames=[localhost], got %v", cf.hostnames) + } + if len(cf.resolved["localhost"]) == 0 { + t.Error("expected at least one resolved IP for localhost") + } +} + +func TestNewClientFilter_MixedIPsHostnamesCIDRs(t *testing.T) { + cf, err := NewClientFilter("10.0.0.1,localhost,172.16.0.0/12") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(cf.ips) != 1 { + t.Errorf("expected 1 static IP, got %d", len(cf.ips)) + } + if len(cf.nets) != 1 { + t.Errorf("expected 1 CIDR, got %d", len(cf.nets)) + } + if len(cf.hostnames) != 1 || cf.hostnames[0] != "localhost" { + t.Errorf("expected hostnames=[localhost], got %v", cf.hostnames) + } +} + +func TestClientFilter_IsAllowed_WithResolvedIPs(t *testing.T) { + cf := &ClientFilter{ + ips: []net.IP{net.ParseIP("10.0.0.1")}, + resolved: map[string][]net.IP{"myhost": {net.ParseIP("192.168.1.100")}}, + stopCh: make(chan struct{}), + } + + // Static IP should match + if !cf.IsAllowed(net.ParseIP("10.0.0.1")) { + t.Error("static IP 10.0.0.1 should be allowed") + } + // Resolved IP should match + if !cf.IsAllowed(net.ParseIP("192.168.1.100")) { + t.Error("resolved IP 192.168.1.100 should be allowed") + } + // Unknown IP should not match + if cf.IsAllowed(net.ParseIP("10.0.0.2")) { + t.Error("unknown IP 10.0.0.2 should not be allowed") + } +} + +func TestClientFilter_IsAllowed_Localhost(t *testing.T) { + cf, err := NewClientFilter("localhost") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + // localhost should resolve to 127.0.0.1 and/or ::1 + if !cf.IsAllowed(net.ParseIP("127.0.0.1")) && !cf.IsAllowed(net.ParseIP("::1")) { + t.Error("expected localhost to resolve to 127.0.0.1 or ::1") + } +} + +func TestClientFilter_StopNilSafe(t *testing.T) { + var cf *ClientFilter + cf.Stop() // nil — should not panic + + cf2, _ := NewClientFilter("10.0.0.1") + cf2.Stop() // no hostnames — should not panic +} + +func TestClientFilter_StopDoubleCall(t *testing.T) { + cf, _ := NewClientFilter("localhost") + cf.StartResolving() + cf.Stop() + cf.Stop() // second call — should not panic +} + +func TestClientFilter_String_WithHostnames(t *testing.T) { + cf, err := NewClientFilter("10.0.0.1,localhost") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + s := cf.String() + if !strings.Contains(s, "10.0.0.1") { + t.Errorf("String() should contain static IP, got %q", s) + } + if !strings.Contains(s, "localhost") { + t.Errorf("String() should contain hostname, got %q", s) + } + if !strings.Contains(s, "resolved:") { + t.Errorf("String() should contain resolved IPs for localhost, got %q", s) + } +} + +func TestClientFilter_StartResolving_NilSafe(t *testing.T) { + // StartResolving on nil should not panic + var cf *ClientFilter + cf.StartResolving() +}