Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cmd/paude-proxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ func main() {
// Credential store and token vendor
credStore, tokenVendor := buildCredentialStore(domainFilter)

// Start background hostname re-resolution (no-op if no hostnames configured)
clientFilter.StartResolving()

// Create and start proxy
srv := proxy.New(proxy.Config{
ListenAddr: listenAddr,
Expand All @@ -113,6 +116,7 @@ func main() {

<-done
log.Println("Shutting down...")
clientFilter.Stop()

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
Expand Down
144 changes: 130 additions & 14 deletions internal/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,20 +105,31 @@ func (bl *BlockedLogger) Close() error {
return bl.file.Close()
}

// ClientFilter validates client source IPs against an allowlist of IPs and CIDRs.
// A nil or empty ClientFilter allows all clients.
// ClientFilter validates client source IPs against an allowlist of IPs, CIDRs,
// and DNS hostnames. Hostnames are resolved to IPs at startup and periodically
// re-resolved in the background (every 30s) to handle dynamic IP assignments
// (e.g., Kubernetes pods restarting). A nil or empty ClientFilter allows all clients.
type ClientFilter struct {
ips []net.IP
nets []*net.IPNet
ips []net.IP
nets []*net.IPNet
hostnames []string
resolved map[string][]net.IP // hostname -> resolved IPs (protected by mu)
mu sync.RWMutex
stopCh chan struct{}
stopOnce sync.Once
}

// NewClientFilter parses a comma-separated list of IPs and CIDRs.
// Returns nil if the input is empty (allow all).
// NewClientFilter parses a comma-separated list of IPs, CIDRs, and DNS hostnames.
// Returns nil if the input is empty (allow all). Hostnames are resolved immediately;
// resolution failures are logged as warnings (the hostname may become resolvable later).
func NewClientFilter(s string) (*ClientFilter, error) {
if s == "" {
return nil, nil
}
cf := &ClientFilter{}
cf := &ClientFilter{
stopCh: make(chan struct{}),
resolved: make(map[string][]net.IP),
}
for _, part := range strings.Split(s, ",") {
part = strings.TrimSpace(part)
if part == "" {
Expand All @@ -130,18 +141,98 @@ func NewClientFilter(s string) (*ClientFilter, error) {
return nil, fmt.Errorf("invalid CIDR %q: %w", part, err)
}
cf.nets = append(cf.nets, ipNet)
} else {
ip := net.ParseIP(part)
if ip == nil {
return nil, fmt.Errorf("invalid IP %q", part)
}
} else if ip := net.ParseIP(part); ip != nil {
cf.ips = append(cf.ips, ip)
} else {
cf.hostnames = append(cf.hostnames, part)
}
}

if len(cf.hostnames) > 0 {
cf.resolveHostnames(true)
}

return cf, nil
}

// IsAllowed returns true if the given IP is in the allowlist.
// resolveHostnames resolves all configured hostnames and updates the resolved IP map.
// When initialResolve is true, all results are logged. Otherwise, only changes are logged.
func (cf *ClientFilter) resolveHostnames(initialResolve bool) {
newResolved := make(map[string][]net.IP, len(cf.hostnames))
for _, hostname := range cf.hostnames {
addrs, err := net.LookupHost(hostname)
if err != nil {
log.Printf("WARNING: failed to resolve allowed client hostname %q: %v", hostname, err)
continue
}
var ips []net.IP
for _, addr := range addrs {
if ip := net.ParseIP(addr); ip != nil {
ips = append(ips, ip)
}
}
newResolved[hostname] = ips
}

cf.mu.Lock()
old := cf.resolved
cf.resolved = newResolved
cf.mu.Unlock()

for _, hostname := range cf.hostnames {
newIPs := newResolved[hostname]
oldIPs := old[hostname]
if initialResolve || !ipsEqual(newIPs, oldIPs) {
ipStrs := make([]string, len(newIPs))
for i, ip := range newIPs {
ipStrs[i] = ip.String()
}
log.Printf("Resolved allowed client hostname %q -> %s", hostname, strings.Join(ipStrs, ", "))
}
}
}

// ipsEqual returns true if two IP slices contain the same IPs in the same order.
func ipsEqual(a, b []net.IP) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if !a[i].Equal(b[i]) {
return false
}
}
return true
}

// StartResolving starts a background goroutine that re-resolves all hostname
// entries every 30 seconds to handle pods restarting with new IPs.
func (cf *ClientFilter) StartResolving() {
if cf == nil || len(cf.hostnames) == 0 {
return
}
go func() {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
cf.resolveHostnames(false)
case <-cf.stopCh:
return
}
}
}()
}

// Stop stops the background hostname re-resolution goroutine. Safe to call multiple times.
func (cf *ClientFilter) Stop() {
if cf == nil || len(cf.hostnames) == 0 {
return
}
cf.stopOnce.Do(func() { close(cf.stopCh) })
}

func (cf *ClientFilter) IsAllowed(ip net.IP) bool {
if cf == nil {
return true
Expand All @@ -156,10 +247,19 @@ func (cf *ClientFilter) IsAllowed(ip net.IP) bool {
return true
}
}
cf.mu.RLock()
resolved := cf.resolved
cf.mu.RUnlock()
for _, ips := range resolved {
for _, allowed := range ips {
if allowed.Equal(ip) {
return true
}
}
}
return false
}

// String returns a human-readable representation of the filter.
func (cf *ClientFilter) String() string {
if cf == nil {
return "disabled (all clients allowed)"
Expand All @@ -171,6 +271,22 @@ func (cf *ClientFilter) String() string {
for _, ipNet := range cf.nets {
parts = append(parts, ipNet.String())
}
if len(cf.hostnames) > 0 {
cf.mu.RLock()
resolved := cf.resolved
cf.mu.RUnlock()
for _, hostname := range cf.hostnames {
if ips, ok := resolved[hostname]; ok && len(ips) > 0 {
ipStrs := make([]string, len(ips))
for i, ip := range ips {
ipStrs[i] = ip.String()
}
parts = append(parts, fmt.Sprintf("%s (resolved: %s)", hostname, strings.Join(ipStrs, ", ")))
} else {
parts = append(parts, fmt.Sprintf("%s (unresolved)", hostname))
}
}
}
return strings.Join(parts, ", ")
}

Expand Down
104 changes: 103 additions & 1 deletion internal/proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ func TestNewClientFilter(t *testing.T) {
{"CIDR", "10.0.0.0/24", false, false},
{"mixed", "10.0.0.1,172.16.0.0/12", false, false},
{"with spaces", " 10.0.0.1 , 10.0.0.2 ", false, false},
{"invalid IP", "notanip", false, true},
{"hostname", "notanip", false, false}, // treated as DNS hostname, not invalid
{"invalid CIDR", "10.0.0.0/99", false, true},
}
for _, tt := range tests {
Expand Down Expand Up @@ -215,3 +215,105 @@ func TestClientFilter_NilAllowsAll(t *testing.T) {
t.Error("nil ClientFilter should allow all IPs")
}
}

func TestNewClientFilter_Hostnames(t *testing.T) {
cf, err := NewClientFilter("localhost")
if err != nil {
t.Fatalf("NewClientFilter(\"localhost\") error: %v", err)
}
if cf == nil {
t.Fatal("expected non-nil ClientFilter")
}
if len(cf.hostnames) != 1 || cf.hostnames[0] != "localhost" {
t.Errorf("expected hostnames=[localhost], got %v", cf.hostnames)
}
if len(cf.resolved["localhost"]) == 0 {
t.Error("expected at least one resolved IP for localhost")
}
}

func TestNewClientFilter_MixedIPsHostnamesCIDRs(t *testing.T) {
cf, err := NewClientFilter("10.0.0.1,localhost,172.16.0.0/12")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(cf.ips) != 1 {
t.Errorf("expected 1 static IP, got %d", len(cf.ips))
}
if len(cf.nets) != 1 {
t.Errorf("expected 1 CIDR, got %d", len(cf.nets))
}
if len(cf.hostnames) != 1 || cf.hostnames[0] != "localhost" {
t.Errorf("expected hostnames=[localhost], got %v", cf.hostnames)
}
}

func TestClientFilter_IsAllowed_WithResolvedIPs(t *testing.T) {
cf := &ClientFilter{
ips: []net.IP{net.ParseIP("10.0.0.1")},
resolved: map[string][]net.IP{"myhost": {net.ParseIP("192.168.1.100")}},
stopCh: make(chan struct{}),
}

// Static IP should match
if !cf.IsAllowed(net.ParseIP("10.0.0.1")) {
t.Error("static IP 10.0.0.1 should be allowed")
}
// Resolved IP should match
if !cf.IsAllowed(net.ParseIP("192.168.1.100")) {
t.Error("resolved IP 192.168.1.100 should be allowed")
}
// Unknown IP should not match
if cf.IsAllowed(net.ParseIP("10.0.0.2")) {
t.Error("unknown IP 10.0.0.2 should not be allowed")
}
}

func TestClientFilter_IsAllowed_Localhost(t *testing.T) {
cf, err := NewClientFilter("localhost")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// localhost should resolve to 127.0.0.1 and/or ::1
if !cf.IsAllowed(net.ParseIP("127.0.0.1")) && !cf.IsAllowed(net.ParseIP("::1")) {
t.Error("expected localhost to resolve to 127.0.0.1 or ::1")
}
}

func TestClientFilter_StopNilSafe(t *testing.T) {
var cf *ClientFilter
cf.Stop() // nil — should not panic

cf2, _ := NewClientFilter("10.0.0.1")
cf2.Stop() // no hostnames — should not panic
}

func TestClientFilter_StopDoubleCall(t *testing.T) {
cf, _ := NewClientFilter("localhost")
cf.StartResolving()
cf.Stop()
cf.Stop() // second call — should not panic
}

func TestClientFilter_String_WithHostnames(t *testing.T) {
cf, err := NewClientFilter("10.0.0.1,localhost")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
s := cf.String()
if !strings.Contains(s, "10.0.0.1") {
t.Errorf("String() should contain static IP, got %q", s)
}
if !strings.Contains(s, "localhost") {
t.Errorf("String() should contain hostname, got %q", s)
}
if !strings.Contains(s, "resolved:") {
t.Errorf("String() should contain resolved IPs for localhost, got %q", s)
}
}

func TestClientFilter_StartResolving_NilSafe(t *testing.T) {
// StartResolving on nil should not panic
var cf *ClientFilter
cf.StartResolving()
}