Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 23 additions & 8 deletions resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,19 +124,29 @@ func (r *Resolver) LookupIPAddr(ctx context.Context, domain string) (result []ne
}

func (r *Resolver) LookupTXT(ctx context.Context, domain string) ([]string, error) {
result, ok := r.getCachedTXT(domain)
if ok {
return result, nil
result, _, err := r.LookupTXTWithTTL(ctx, domain)
return result, err
}

// LookupTXTWithTTL is like [Resolver.LookupTXT] but also returns how long the
// TXT records may be cached. The TTL is the smallest Ttl across the answer's
// TXT resource records, capped by the resolver's max cache TTL. On a cache hit
// it is the remaining lifetime of the cached entry, so the value shrinks as the
// entry ages. A TTL of 0 means the TTL is unknown, for example when the
// upstream resolver does not provide one.
func (r *Resolver) LookupTXTWithTTL(ctx context.Context, domain string) ([]string, time.Duration, error) {
if result, ttl, ok := r.getCachedTXTWithTTL(domain); ok {
return result, ttl, nil
}

result, ttl, err := doRequestTXT(ctx, r.url, domain)
if err != nil {
return nil, err
return nil, 0, err
}

cacheTTL := minTTL(time.Duration(ttl)*time.Second, r.maxCacheTTL)
r.cacheTXT(domain, result, cacheTTL)
return result, nil
return result, cacheTTL, nil
}

func (r *Resolver) getCachedIPAddr(domain string) ([]net.IPAddr, bool) {
Expand Down Expand Up @@ -170,21 +180,26 @@ func (r *Resolver) cacheIPAddr(domain string, ips []net.IPAddr, ttl time.Duratio
}

func (r *Resolver) getCachedTXT(domain string) ([]string, bool) {
txt, _, ok := r.getCachedTXTWithTTL(domain)
return txt, ok
}

func (r *Resolver) getCachedTXTWithTTL(domain string) ([]string, time.Duration, bool) {
r.mx.Lock()
defer r.mx.Unlock()

fqdn := dns.Fqdn(domain)
entry, ok := r.txtCache[fqdn]
if !ok {
return nil, false
return nil, 0, false
}

if time.Now().After(entry.expire) {
delete(r.txtCache, fqdn)
return nil, false
return nil, 0, false
}

return entry.txt, true
return entry.txt, time.Until(entry.expire), true
}

func (r *Resolver) cacheTXT(domain string, txt []string, ttl time.Duration) {
Expand Down
56 changes: 56 additions & 0 deletions resolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,62 @@ func TestLookupTXT(t *testing.T) {
}
}

func TestLookupTXTWithTTL(t *testing.T) {
domain := "example.com"
resolver := mockDoHResolver(t, map[uint16]*dns.Msg{
dns.TypeTXT: mockDNSAnswerTXT(dns.Fqdn(domain), []string{"dnslink=/ipns/example.com"}),
})
defer resolver.Close()

r, err := NewResolver(resolver.URL)
if err != nil {
t.Fatal("resolver cannot be initialised")
}

// cold lookup returns the record TTL from the answer (300s in the mock)
txt, ttl, err := r.LookupTXTWithTTL(context.Background(), domain)
if err != nil {
t.Fatal(err)
}
if len(txt) == 0 {
t.Fatal("got no TXT entries")
}
if ttl != 300*time.Second {
t.Fatalf("expected ttl 300s, got %s", ttl)
}

// warm lookup (cache hit) returns the remaining TTL, never more than the record TTL
_, ttl2, err := r.LookupTXTWithTTL(context.Background(), domain)
if err != nil {
t.Fatal(err)
}
if ttl2 <= 0 || ttl2 > 300*time.Second {
t.Fatalf("expected remaining ttl in (0s, 300s], got %s", ttl2)
}
}

func TestLookupTXTWithTTLCappedByMaxCacheTTL(t *testing.T) {
domain := "example.com"
resolver := mockDoHResolver(t, map[uint16]*dns.Msg{
dns.TypeTXT: mockDNSAnswerTXT(dns.Fqdn(domain), []string{"dnslink=/ipns/example.com"}),
})
defer resolver.Close()

// record TTL (300s) is larger than the max cache TTL, so the returned TTL is capped
r, err := NewResolver(resolver.URL, WithMaxCacheTTL(10*time.Second))
if err != nil {
t.Fatal("resolver cannot be initialised")
}

_, ttl, err := r.LookupTXTWithTTL(context.Background(), domain)
if err != nil {
t.Fatal(err)
}
if ttl != 10*time.Second {
t.Fatalf("expected ttl capped to 10s, got %s", ttl)
}
}

func TestLookupCache(t *testing.T) {
domain := "example.com"
resolver := mockDoHResolver(t, map[uint16]*dns.Msg{
Expand Down
Loading