Skip to content

Commit 9f98c96

Browse files
authored
Preserve DNS provider record data for cleanup (#382)
Co-authored-by: bjornmp <bjornmp@users.noreply.github.com>
1 parent 49b3509 commit 9f98c96

2 files changed

Lines changed: 76 additions & 13 deletions

File tree

solvers.go

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ func (m *DNSManager) createRecord(ctx context.Context, dnsName, recordType, reco
410410
return zoneRecord{}, fmt.Errorf("expected one record, got %d: %v", len(results), results)
411411
}
412412

413-
return zoneRecord{zone, results[0].RR()}, nil
413+
return zoneRecord{zone, results[0]}, nil
414414
}
415415

416416
// wait blocks until the TXT record created in Present() appears in
@@ -445,12 +445,13 @@ func (m *DNSManager) wait(ctx context.Context, zrec zoneRecord) error {
445445
checkAuthoritativeServers := len(m.Resolvers) == 0
446446
resolvers := RecursiveNameservers(m.Resolvers)
447447

448+
rr := zrec.record.RR()
448449
recType := dns.TypeTXT
449-
if zrec.record.RR().Type == "CNAME" {
450+
if rr.Type == "CNAME" {
450451
recType = dns.TypeCNAME
451452
}
452453

453-
absName := libdns.AbsoluteName(zrec.record.Name, zrec.zone)
454+
absName := libdns.AbsoluteName(rr.Name, zrec.zone)
454455

455456
var err error
456457
start := time.Now()
@@ -463,14 +464,14 @@ func (m *DNSManager) wait(ctx context.Context, zrec zoneRecord) error {
463464

464465
logger.Debug("checking DNS propagation",
465466
zap.String("fqdn", absName),
466-
zap.String("record_type", zrec.record.Type),
467-
zap.String("expected_data", zrec.record.Data),
467+
zap.String("record_type", rr.Type),
468+
zap.String("expected_data", rr.Data),
468469
zap.Strings("resolvers", resolvers))
469470

470471
var ready bool
471-
ready, err = checkDNSPropagation(ctx, logger, absName, recType, zrec.record.Data, checkAuthoritativeServers, resolvers)
472+
ready, err = checkDNSPropagation(ctx, logger, absName, recType, rr.Data, checkAuthoritativeServers, resolvers)
472473
if err != nil {
473-
return fmt.Errorf("checking DNS propagation of %q (relative=%s zone=%s resolvers=%v): %w", absName, zrec.record.Name, zrec.zone, resolvers, err)
474+
return fmt.Errorf("checking DNS propagation of %q (relative=%s zone=%s resolvers=%v): %w", absName, rr.Name, zrec.zone, resolvers, err)
474475
}
475476
if ready {
476477
return nil
@@ -482,7 +483,7 @@ func (m *DNSManager) wait(ctx context.Context, zrec zoneRecord) error {
482483

483484
type zoneRecord struct {
484485
zone string
485-
record libdns.RR
486+
record libdns.Record
486487
}
487488

488489
// CleanUp deletes the DNS TXT record created in Present().
@@ -506,11 +507,12 @@ func (m *DNSManager) cleanUpRecord(_ context.Context, zrec zoneRecord) error {
506507
ctx, cancel := context.WithTimeout(context.Background(), timeout)
507508
defer cancel()
508509

510+
rr := zrec.record.RR()
509511
logger.Debug("deleting DNS record",
510512
zap.String("zone", zrec.zone),
511-
zap.String("record_name", zrec.record.Name),
512-
zap.String("record_type", zrec.record.Type),
513-
zap.String("record_data", zrec.record.Data))
513+
zap.String("record_name", rr.Name),
514+
zap.String("record_type", rr.Type),
515+
zap.String("record_data", rr.Data))
514516

515517
_, err := m.DNSProvider.DeleteRecords(ctx, zrec.zone, []libdns.Record{zrec.record})
516518
if err != nil {
@@ -552,7 +554,8 @@ func (s *DNSManager) getDNSPresentMemory(dnsName, recType, value string) (dnsPre
552554
var memory dnsPresentMemory
553555
var found bool
554556
for _, mem := range s.records[dnsName] {
555-
if mem.zoneRec.record.Type == recType && mem.zoneRec.record.Data == value {
557+
rr := mem.zoneRec.record.RR()
558+
if rr.Type == recType && rr.Data == value {
556559
memory = mem
557560
found = true
558561
break
@@ -570,7 +573,7 @@ func (s *DNSManager) deleteDNSPresentMemory(dnsName, keyAuth string) {
570573
defer s.recordsMu.Unlock()
571574

572575
for i, mem := range s.records[dnsName] {
573-
if mem.zoneRec.record.Data == keyAuth {
576+
if mem.zoneRec.record.RR().Data == keyAuth {
574577
s.records[dnsName] = append(s.records[dnsName][:i], s.records[dnsName][i+1:]...)
575578
return
576579
}

solvers_test.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,14 @@
1515
package certmagic
1616

1717
import (
18+
"context"
1819
"net"
1920
"strconv"
2021
"testing"
22+
"time"
2123

2224
"github.com/caddyserver/certmagic/internal/filedescriptor"
25+
"github.com/libdns/libdns"
2326
"github.com/mholt/acmez/v3/acme"
2427
)
2528

@@ -159,6 +162,63 @@ func Test_challengeKey(t *testing.T) {
159162
}
160163
}
161164

165+
func TestDNSManagerCleanUpRecordPreservesProviderData(t *testing.T) {
166+
provider := &providerDataDeleteProvider{t: t}
167+
manager := DNSManager{
168+
DNSProvider: provider,
169+
PropagationTimeout: time.Second,
170+
}
171+
172+
err := manager.cleanUpRecord(context.Background(), zoneRecord{
173+
zone: "example.com.",
174+
record: libdns.TXT{
175+
Name: "_acme-challenge",
176+
Text: "token",
177+
TTL: time.Minute,
178+
ProviderData: map[string]string{
179+
"id": "123",
180+
},
181+
},
182+
})
183+
if err != nil {
184+
t.Fatalf("cleanup failed: %v", err)
185+
}
186+
if !provider.deleted {
187+
t.Fatal("expected DeleteRecords to be called")
188+
}
189+
}
190+
191+
type providerDataDeleteProvider struct {
192+
t *testing.T
193+
deleted bool
194+
}
195+
196+
func (p *providerDataDeleteProvider) AppendRecords(ctx context.Context, zone string, records []libdns.Record) ([]libdns.Record, error) {
197+
return records, nil
198+
}
199+
200+
func (p *providerDataDeleteProvider) DeleteRecords(ctx context.Context, zone string, records []libdns.Record) ([]libdns.Record, error) {
201+
p.deleted = true
202+
if zone != "example.com." {
203+
p.t.Fatalf("expected zone example.com., got %q", zone)
204+
}
205+
if len(records) != 1 {
206+
p.t.Fatalf("expected 1 record, got %d", len(records))
207+
}
208+
txt, ok := records[0].(libdns.TXT)
209+
if !ok {
210+
p.t.Fatalf("expected libdns.TXT with provider data, got %T", records[0])
211+
}
212+
pd, ok := txt.ProviderData.(map[string]string)
213+
if !ok {
214+
p.t.Fatalf("expected ProviderData map, got %T", txt.ProviderData)
215+
}
216+
if pd["id"] != "123" {
217+
p.t.Fatalf("expected provider ID 123, got %q", pd["id"])
218+
}
219+
return records, nil
220+
}
221+
162222
func TestGetACMEChallenge_IPv6Brackets(t *testing.T) {
163223
// Store a challenge under a bare IPv6 identifier (as CertMagic does internally).
164224
bare := "::1"

0 commit comments

Comments
 (0)