Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 50 additions & 1 deletion net/egress/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,21 @@ package egress
import (
"fmt"
"net"
"strings"

"github.com/miekg/dns"
)

// normalizeDNSName returns s lower-cased with a single trailing dot stripped,
// so that names from DNS wire format (often FQDN with trailing dot) compare
// consistently against caller-supplied policy names (typically without a
// trailing dot). DNS names are case-insensitive by RFC 1035.
func normalizeDNSName(s string) string {
s = strings.ToLower(s)
s = strings.TrimSuffix(s, ".")
return s
}

// ParseDNSQuery extracts the transaction ID, question name, and query type
// from a DNS query payload (UDP payload, not including Ethernet/IP/UDP headers).
func ParseDNSQuery(payload []byte) (txnID uint16, qname string, qtype uint16, err error) {
Expand Down Expand Up @@ -52,13 +63,48 @@ func BuildNXDOMAIN(txnID uint16, qname string, qtype uint16) ([]byte, error) {

// ParseDNSResponse extracts the question name, A-record IPs, and minimum
// TTL from a DNS response payload.
//
// Only A records whose owner name is reachable from the question via the
// response's own CNAME chain are returned. An answer with an unrelated
// owner name — for example, an injected out-of-bailiwick record — is
// discarded. Without this filter, a compromised allowed zone could
// smuggle arbitrary IPs into dynamic egress rules by returning them in
// the Answer section under any name.
//
// The returned qname is normalized (lower-cased, trailing dot stripped).
func ParseDNSResponse(payload []byte) (qname string, ips []net.IP, ttl uint32, err error) {
var msg dns.Msg
if err := msg.Unpack(payload); err != nil {
return "", nil, 0, fmt.Errorf("unpack DNS response: %w", err)
}
if len(msg.Question) > 0 {
qname = msg.Question[0].Name
qname = normalizeDNSName(msg.Question[0].Name)
}

// Build the set of owner names reachable from qname via CNAMEs in this
// response. Multi-pass until no new names are added; bounded by the
// Answer count so CNAME loops cannot cause infinite iteration.
validNames := map[string]struct{}{qname: {}}
for i := 0; i < len(msg.Answer); i++ {
progress := false
for _, rr := range msg.Answer {
cn, ok := rr.(*dns.CNAME)
if !ok {
continue
}
owner := normalizeDNSName(cn.Hdr.Name)
if _, have := validNames[owner]; !have {
continue
}
target := normalizeDNSName(cn.Target)
if _, already := validNames[target]; !already {
validNames[target] = struct{}{}
progress = true
}
}
if !progress {
break
}
}

var minTTL uint32
Expand All @@ -68,6 +114,9 @@ func ParseDNSResponse(payload []byte) (qname string, ips []net.IP, ttl uint32, e
if !ok {
continue
}
if _, ok := validNames[normalizeDNSName(a.Hdr.Name)]; !ok {
continue
}
ips = append(ips, a.A)
if first || a.Hdr.Ttl < minTTL {
minTTL = a.Hdr.Ttl
Expand Down
128 changes: 127 additions & 1 deletion net/egress/dns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,31 @@ import (
"github.com/stretchr/testify/require"
)

func TestNormalizeDNSName(t *testing.T) {
t.Parallel()

tests := []struct {
in, want string
}{
{"example.com", "example.com"},
{"example.com.", "example.com"},
{"Example.COM.", "example.com"},
{"API.GitHub.com", "api.github.com"},
{"", ""},
{".", ""},
// Only a single trailing dot is stripped; double-trailing is a
// malformed name and should not be coerced.
{"foo..", "foo."},
}

for _, tt := range tests {
t.Run(tt.in, func(t *testing.T) {
t.Parallel()
assert.Equal(t, tt.want, normalizeDNSName(tt.in))
})
}
}

func TestParseDNSQuery(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -121,13 +146,114 @@ func TestParseDNSResponse(t *testing.T) {

qname, ips, ttl, err := ParseDNSResponse(payload)
require.NoError(t, err)
assert.Equal(t, "example.com.", qname)
assert.Equal(t, "example.com", qname)
assert.Len(t, ips, 2)
assert.Equal(t, "93.184.216.34", ips[0].String())
assert.Equal(t, "93.184.216.35", ips[1].String())
assert.Equal(t, uint32(60), ttl) // minimum TTL
}

func TestParseDNSResponse_DropsOutOfBailiwickAnswers(t *testing.T) {
t.Parallel()

// Question for example.com; attacker-controlled response slips an A
// record with a different owner name (typical out-of-bailiwick injection
// attempt to smuggle an internal IP into dynamic rules).
msg := &dns.Msg{
MsgHdr: dns.MsgHdr{Id: 0x1234, Response: true},
Question: []dns.Question{
{Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
},
Answer: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 300},
A: net.ParseIP("93.184.216.34"),
},
&dns.A{
Hdr: dns.RR_Header{Name: "internal.local.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 300},
A: net.ParseIP("169.254.169.254"),
},
},
}
payload, err := msg.Pack()
require.NoError(t, err)

qname, ips, _, err := ParseDNSResponse(payload)
require.NoError(t, err)
assert.Equal(t, "example.com", qname)
require.Len(t, ips, 1)
assert.Equal(t, "93.184.216.34", ips[0].String())
}

func TestParseDNSResponse_FollowsCNAMEChain(t *testing.T) {
t.Parallel()

// Legitimate CNAME chain: example.com -> cdn.example.net -> 1.2.3.4.
// The A record's owner matches the CNAME target, which is reachable
// from the question via the chain, so the IP is accepted.
msg := &dns.Msg{
MsgHdr: dns.MsgHdr{Id: 0x2345, Response: true},
Question: []dns.Question{
{Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
},
Answer: []dns.RR{
&dns.CNAME{
Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: 300},
Target: "cdn.example.net.",
},
&dns.A{
Hdr: dns.RR_Header{Name: "cdn.example.net.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 300},
A: net.ParseIP("1.2.3.4"),
},
},
}
payload, err := msg.Pack()
require.NoError(t, err)

qname, ips, _, err := ParseDNSResponse(payload)
require.NoError(t, err)
assert.Equal(t, "example.com", qname)
require.Len(t, ips, 1)
assert.Equal(t, "1.2.3.4", ips[0].String())
}

func TestParseDNSResponse_DropsUnreachableCNAMEA(t *testing.T) {
t.Parallel()

// An A record whose owner is NOT reachable via any CNAME chain from
// the question must be dropped even if another A record from the same
// name-chain is valid.
msg := &dns.Msg{
MsgHdr: dns.MsgHdr{Id: 0x3456, Response: true},
Question: []dns.Question{
{Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
},
Answer: []dns.RR{
&dns.CNAME{
Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: 300},
Target: "cdn.example.net.",
},
// A record for an unrelated name slipped into the Answer section.
&dns.A{
Hdr: dns.RR_Header{Name: "attacker.example.net.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 300},
A: net.ParseIP("10.0.0.1"),
},
// Legitimate A record at the CNAME target.
&dns.A{
Hdr: dns.RR_Header{Name: "cdn.example.net.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 300},
A: net.ParseIP("1.2.3.4"),
},
},
}
payload, err := msg.Pack()
require.NoError(t, err)

_, ips, _, err := ParseDNSResponse(payload)
require.NoError(t, err)
require.Len(t, ips, 1)
assert.Equal(t, "1.2.3.4", ips[0].String())
}

func TestParseDNSResponse_NoARecords(t *testing.T) {
t.Parallel()

Expand Down
44 changes: 40 additions & 4 deletions net/egress/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@ const (
// if the DNS response has a shorter TTL. This prevents excessive
// rule churn from very short TTLs.
defaultMinTTL = 60 * time.Second

// defaultMaxTTL is the maximum TTL applied to dynamic rules, even
// if the DNS response advertises a longer TTL. This bounds how long
// a resolved IP remains allowed — useful if an upstream server
// returns very long TTLs, and as belt-and-suspenders against a
// compromised zone returning long-lived rogue answers.
defaultMaxTTL = 5 * time.Minute
)

// DNSInterceptor intercepts DNS traffic at the relay level to enforce
Expand All @@ -25,21 +32,47 @@ type DNSInterceptor struct {
policy *Policy
dynamicRules *firewall.DynamicRules
minTTL time.Duration
maxTTL time.Duration
gatewayIP [4]byte
}

// DNSInterceptorOption customizes a DNSInterceptor.
type DNSInterceptorOption func(*DNSInterceptor)

// WithMinTTL sets the minimum TTL applied to dynamic rules. A zero or
// negative value leaves the default in place.
func WithMinTTL(d time.Duration) DNSInterceptorOption {
return func(i *DNSInterceptor) {
if d > 0 {
i.minTTL = d
}
}
}

// WithMaxTTL sets the maximum TTL applied to dynamic rules. A zero or
// negative value disables the cap (any TTL is accepted).
func WithMaxTTL(d time.Duration) DNSInterceptorOption {
return func(i *DNSInterceptor) { i.maxTTL = d }
}

// NewDNSInterceptor creates an interceptor with the given policy, dynamic
// rule set, and gateway IP. Only DNS responses from the gateway are
// snooped to prevent spoofed responses from creating dynamic rules.
// Dynamic rules created from DNS responses will have at least minTTL
// duration (use 0 for the default of 60 seconds).
func NewDNSInterceptor(policy *Policy, dr *firewall.DynamicRules, gatewayIP [4]byte) *DNSInterceptor {
return &DNSInterceptor{
//
// By default, dynamic rules are clamped to minTTL=60s and maxTTL=5m;
// override via WithMinTTL / WithMaxTTL.
func NewDNSInterceptor(policy *Policy, dr *firewall.DynamicRules, gatewayIP [4]byte, opts ...DNSInterceptorOption) *DNSInterceptor {
i := &DNSInterceptor{
policy: policy,
dynamicRules: dr,
minTTL: defaultMinTTL,
maxTTL: defaultMaxTTL,
gatewayIP: gatewayIP,
}
for _, o := range opts {
o(i)
}
return i
}

// HandleEgress processes an outbound DNS query frame. If the queried
Expand Down Expand Up @@ -118,6 +151,9 @@ func (d *DNSInterceptor) HandleIngress(frame []byte, hdr *firewall.PacketHeader)
if ttl < d.minTTL {
ttl = d.minTTL
}
if d.maxTTL > 0 && ttl > d.maxTTL {
ttl = d.maxTTL
}

ports, proto := d.policy.HostPorts(qname)

Expand Down
32 changes: 32 additions & 0 deletions net/egress/interceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"encoding/binary"
"net"
"testing"
"time"

mdns "github.com/miekg/dns"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -304,6 +305,37 @@ func TestDNSInterceptor_ResponseFromNonGateway_Ignored(t *testing.T) {
"DNS responses from non-gateway sources must not create dynamic rules")
}

func TestDNSInterceptor_ClampsTTLAtMaximum(t *testing.T) {
t.Parallel()

policy := NewPolicy([]HostSpec{{Name: "example.com"}})
dr := firewall.NewDynamicRules()
interceptor := NewDNSInterceptor(policy, dr, testDstIP,
WithMinTTL(1*time.Microsecond),
WithMaxTTL(5*time.Millisecond),
)

// Response advertises TTL = 1 hour; should be clamped to 5 ms.
ips := []net.IP{net.ParseIP("1.2.3.4")}
frame := buildDNSResponseFrame(testDstMAC, testSrcMAC, testDstIP, testSrcIP, 12345, "example.com", ips, 3600)
hdr := firewall.ParseHeaders(frame)

interceptor.HandleIngress(frame, hdr)

probe := &firewall.PacketHeader{
DstIP: [4]byte{1, 2, 3, 4},
Protocol: 6,
DstPort: 443,
}
_, ok := dr.Match(firewall.Egress, probe)
require.True(t, ok, "rule should be live immediately after ingress")

// Give the clamp time to elapse; the rule must no longer match.
time.Sleep(25 * time.Millisecond)
_, ok = dr.Match(firewall.Egress, probe)
assert.False(t, ok, "rule should have expired past maxTTL clamp")
}

func TestDNSInterceptor_ExplicitProtocol_SingleRule(t *testing.T) {
t.Parallel()

Expand Down
Loading