diff --git a/net/egress/dns.go b/net/egress/dns.go index 4046445..fe3e9b5 100644 --- a/net/egress/dns.go +++ b/net/egress/dns.go @@ -6,10 +6,21 @@ package egress import ( "fmt" "net" + "strings" "github.com/miekg/dns" ) +// normalizeDNSName returns s lower-cased with a single trailing dot stripped, +// so that names from DNS wire format (often FQDN with trailing dot) compare +// consistently against caller-supplied policy names (typically without a +// trailing dot). DNS names are case-insensitive by RFC 1035. +func normalizeDNSName(s string) string { + s = strings.ToLower(s) + s = strings.TrimSuffix(s, ".") + return s +} + // ParseDNSQuery extracts the transaction ID, question name, and query type // from a DNS query payload (UDP payload, not including Ethernet/IP/UDP headers). func ParseDNSQuery(payload []byte) (txnID uint16, qname string, qtype uint16, err error) { @@ -52,13 +63,48 @@ func BuildNXDOMAIN(txnID uint16, qname string, qtype uint16) ([]byte, error) { // ParseDNSResponse extracts the question name, A-record IPs, and minimum // TTL from a DNS response payload. +// +// Only A records whose owner name is reachable from the question via the +// response's own CNAME chain are returned. An answer with an unrelated +// owner name — for example, an injected out-of-bailiwick record — is +// discarded. Without this filter, a compromised allowed zone could +// smuggle arbitrary IPs into dynamic egress rules by returning them in +// the Answer section under any name. +// +// The returned qname is normalized (lower-cased, trailing dot stripped). func ParseDNSResponse(payload []byte) (qname string, ips []net.IP, ttl uint32, err error) { var msg dns.Msg if err := msg.Unpack(payload); err != nil { return "", nil, 0, fmt.Errorf("unpack DNS response: %w", err) } if len(msg.Question) > 0 { - qname = msg.Question[0].Name + qname = normalizeDNSName(msg.Question[0].Name) + } + + // Build the set of owner names reachable from qname via CNAMEs in this + // response. 
Multi-pass until no new names are added; bounded by the + // Answer count so CNAME loops cannot cause infinite iteration. + validNames := map[string]struct{}{qname: {}} + for i := 0; i < len(msg.Answer); i++ { + progress := false + for _, rr := range msg.Answer { + cn, ok := rr.(*dns.CNAME) + if !ok { + continue + } + owner := normalizeDNSName(cn.Hdr.Name) + if _, have := validNames[owner]; !have { + continue + } + target := normalizeDNSName(cn.Target) + if _, already := validNames[target]; !already { + validNames[target] = struct{}{} + progress = true + } + } + if !progress { + break + } } var minTTL uint32 @@ -68,6 +114,9 @@ func ParseDNSResponse(payload []byte) (qname string, ips []net.IP, ttl uint32, e if !ok { continue } + if _, ok := validNames[normalizeDNSName(a.Hdr.Name)]; !ok { + continue + } ips = append(ips, a.A) if first || a.Hdr.Ttl < minTTL { minTTL = a.Hdr.Ttl diff --git a/net/egress/dns_test.go b/net/egress/dns_test.go index 92ea8b8..6a6b398 100644 --- a/net/egress/dns_test.go +++ b/net/egress/dns_test.go @@ -12,6 +12,31 @@ import ( "github.com/stretchr/testify/require" ) +func TestNormalizeDNSName(t *testing.T) { + t.Parallel() + + tests := []struct { + in, want string + }{ + {"example.com", "example.com"}, + {"example.com.", "example.com"}, + {"Example.COM.", "example.com"}, + {"API.GitHub.com", "api.github.com"}, + {"", ""}, + {".", ""}, + // Only a single trailing dot is stripped; double-trailing is a + // malformed name and should not be coerced. 
+ {"foo..", "foo."}, + } + + for _, tt := range tests { + t.Run(tt.in, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.want, normalizeDNSName(tt.in)) + }) + } +} + func TestParseDNSQuery(t *testing.T) { t.Parallel() @@ -121,13 +146,114 @@ func TestParseDNSResponse(t *testing.T) { qname, ips, ttl, err := ParseDNSResponse(payload) require.NoError(t, err) - assert.Equal(t, "example.com.", qname) + assert.Equal(t, "example.com", qname) assert.Len(t, ips, 2) assert.Equal(t, "93.184.216.34", ips[0].String()) assert.Equal(t, "93.184.216.35", ips[1].String()) assert.Equal(t, uint32(60), ttl) // minimum TTL } +func TestParseDNSResponse_DropsOutOfBailiwickAnswers(t *testing.T) { + t.Parallel() + + // Question for example.com; attacker-controlled response slips an A + // record with a different owner name (typical out-of-bailiwick injection + // attempt to smuggle an internal IP into dynamic rules). + msg := &dns.Msg{ + MsgHdr: dns.MsgHdr{Id: 0x1234, Response: true}, + Question: []dns.Question{ + {Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, + }, + Answer: []dns.RR{ + &dns.A{ + Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 300}, + A: net.ParseIP("93.184.216.34"), + }, + &dns.A{ + Hdr: dns.RR_Header{Name: "internal.local.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 300}, + A: net.ParseIP("169.254.169.254"), + }, + }, + } + payload, err := msg.Pack() + require.NoError(t, err) + + qname, ips, _, err := ParseDNSResponse(payload) + require.NoError(t, err) + assert.Equal(t, "example.com", qname) + require.Len(t, ips, 1) + assert.Equal(t, "93.184.216.34", ips[0].String()) +} + +func TestParseDNSResponse_FollowsCNAMEChain(t *testing.T) { + t.Parallel() + + // Legitimate CNAME chain: example.com -> cdn.example.net -> 1.2.3.4. + // The A record's owner matches the CNAME target, which is reachable + // from the question via the chain, so the IP is accepted. 
+ msg := &dns.Msg{ + MsgHdr: dns.MsgHdr{Id: 0x2345, Response: true}, + Question: []dns.Question{ + {Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, + }, + Answer: []dns.RR{ + &dns.CNAME{ + Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: 300}, + Target: "cdn.example.net.", + }, + &dns.A{ + Hdr: dns.RR_Header{Name: "cdn.example.net.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 300}, + A: net.ParseIP("1.2.3.4"), + }, + }, + } + payload, err := msg.Pack() + require.NoError(t, err) + + qname, ips, _, err := ParseDNSResponse(payload) + require.NoError(t, err) + assert.Equal(t, "example.com", qname) + require.Len(t, ips, 1) + assert.Equal(t, "1.2.3.4", ips[0].String()) +} + +func TestParseDNSResponse_DropsUnreachableCNAMEA(t *testing.T) { + t.Parallel() + + // An A record whose owner is NOT reachable via any CNAME chain from + // the question must be dropped even if another A record in the same + // response is valid. + msg := &dns.Msg{ + MsgHdr: dns.MsgHdr{Id: 0x3456, Response: true}, + Question: []dns.Question{ + {Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, + }, + Answer: []dns.RR{ + &dns.CNAME{ + Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: 300}, + Target: "cdn.example.net.", + }, + // A record for an unrelated name slipped into the Answer section. + &dns.A{ + Hdr: dns.RR_Header{Name: "attacker.example.net.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 300}, + A: net.ParseIP("10.0.0.1"), + }, + // Legitimate A record at the CNAME target. 
+ &dns.A{ + Hdr: dns.RR_Header{Name: "cdn.example.net.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 300}, + A: net.ParseIP("1.2.3.4"), + }, + }, + } + payload, err := msg.Pack() + require.NoError(t, err) + + _, ips, _, err := ParseDNSResponse(payload) + require.NoError(t, err) + require.Len(t, ips, 1) + assert.Equal(t, "1.2.3.4", ips[0].String()) +} + func TestParseDNSResponse_NoARecords(t *testing.T) { t.Parallel() diff --git a/net/egress/interceptor.go b/net/egress/interceptor.go index 0dcef13..c42294f 100644 --- a/net/egress/interceptor.go +++ b/net/egress/interceptor.go @@ -16,6 +16,13 @@ const ( // if the DNS response has a shorter TTL. This prevents excessive // rule churn from very short TTLs. defaultMinTTL = 60 * time.Second + + // defaultMaxTTL is the maximum TTL applied to dynamic rules, even + // if the DNS response advertises a longer TTL. This bounds how long + // a resolved IP remains allowed — useful if an upstream server + // returns very long TTLs, and as belt-and-suspenders against a + // compromised zone returning long-lived rogue answers. + defaultMaxTTL = 5 * time.Minute ) // DNSInterceptor intercepts DNS traffic at the relay level to enforce @@ -25,21 +32,47 @@ type DNSInterceptor struct { policy *Policy dynamicRules *firewall.DynamicRules minTTL time.Duration + maxTTL time.Duration gatewayIP [4]byte } +// DNSInterceptorOption customizes a DNSInterceptor. +type DNSInterceptorOption func(*DNSInterceptor) + +// WithMinTTL sets the minimum TTL applied to dynamic rules. A zero or +// negative value leaves the default in place. +func WithMinTTL(d time.Duration) DNSInterceptorOption { + return func(i *DNSInterceptor) { + if d > 0 { + i.minTTL = d + } + } +} + +// WithMaxTTL sets the maximum TTL applied to dynamic rules. A zero or +// negative value disables the cap (any TTL is accepted). 
+func WithMaxTTL(d time.Duration) DNSInterceptorOption { + return func(i *DNSInterceptor) { i.maxTTL = d } +} + // NewDNSInterceptor creates an interceptor with the given policy, dynamic // rule set, and gateway IP. Only DNS responses from the gateway are // snooped to prevent spoofed responses from creating dynamic rules. -// Dynamic rules created from DNS responses will have at least minTTL -// duration (use 0 for the default of 60 seconds). -func NewDNSInterceptor(policy *Policy, dr *firewall.DynamicRules, gatewayIP [4]byte) *DNSInterceptor { - return &DNSInterceptor{ +// +// By default, dynamic rules are clamped to minTTL=60s and maxTTL=5m; +// override via WithMinTTL / WithMaxTTL. +func NewDNSInterceptor(policy *Policy, dr *firewall.DynamicRules, gatewayIP [4]byte, opts ...DNSInterceptorOption) *DNSInterceptor { + i := &DNSInterceptor{ policy: policy, dynamicRules: dr, minTTL: defaultMinTTL, + maxTTL: defaultMaxTTL, gatewayIP: gatewayIP, } + for _, o := range opts { + o(i) + } + return i } // HandleEgress processes an outbound DNS query frame. 
If the queried @@ -118,6 +151,9 @@ func (d *DNSInterceptor) HandleIngress(frame []byte, hdr *firewall.PacketHeader) if ttl < d.minTTL { ttl = d.minTTL } + if d.maxTTL > 0 && ttl > d.maxTTL { + ttl = d.maxTTL + } ports, proto := d.policy.HostPorts(qname) diff --git a/net/egress/interceptor_test.go b/net/egress/interceptor_test.go index 47b0a2d..c20c17f 100644 --- a/net/egress/interceptor_test.go +++ b/net/egress/interceptor_test.go @@ -7,6 +7,7 @@ import ( "encoding/binary" "net" "testing" + "time" mdns "github.com/miekg/dns" "github.com/stretchr/testify/assert" @@ -304,6 +305,37 @@ func TestDNSInterceptor_ResponseFromNonGateway_Ignored(t *testing.T) { "DNS responses from non-gateway sources must not create dynamic rules") } +func TestDNSInterceptor_ClampsTTLAtMaximum(t *testing.T) { + t.Parallel() + + policy := NewPolicy([]HostSpec{{Name: "example.com"}}) + dr := firewall.NewDynamicRules() + interceptor := NewDNSInterceptor(policy, dr, testDstIP, + WithMinTTL(1*time.Microsecond), + WithMaxTTL(5*time.Millisecond), + ) + + // Response advertises TTL = 1 hour; should be clamped to 5 ms. + ips := []net.IP{net.ParseIP("1.2.3.4")} + frame := buildDNSResponseFrame(testDstMAC, testSrcMAC, testDstIP, testSrcIP, 12345, "example.com", ips, 3600) + hdr := firewall.ParseHeaders(frame) + + interceptor.HandleIngress(frame, hdr) + + probe := &firewall.PacketHeader{ + DstIP: [4]byte{1, 2, 3, 4}, + Protocol: 6, + DstPort: 443, + } + _, ok := dr.Match(firewall.Egress, probe) + require.True(t, ok, "rule should be live immediately after ingress") + + // Give the clamp time to elapse; the rule must no longer match. + time.Sleep(25 * time.Millisecond) + _, ok = dr.Match(firewall.Egress, probe) + assert.False(t, ok, "rule should have expired past maxTTL clamp") +} + func TestDNSInterceptor_ExplicitProtocol_SingleRule(t *testing.T) { t.Parallel()