Skip to content

Commit b4ce078

Browse files
edw-defangedwardrf
authored andcommitted
Fix resolverAt race condition
1 parent 1dbacad commit b4ce078

4 files changed

Lines changed: 72 additions & 24 deletions

File tree

src/pkg/dns/check_test.go

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ var notFound = errors.New("not found")
1313

1414
func TestGetCNAMEInSync(t *testing.T) {
1515
t.Cleanup(func() {
16-
ResolverAt = DirectResolverAt
16+
resolverAt = DirectResolverAt
1717
})
1818

1919
notFoundResolver := MockResolver{Records: map[DNSRequest]DNSResponse{
@@ -27,7 +27,7 @@ func TestGetCNAMEInSync(t *testing.T) {
2727

2828
// Test when the domain is not found
2929
t.Run("domain not found", func(t *testing.T) {
30-
ResolverAt = func(_ string) Resolver { return notFoundResolver }
30+
resolverAt = func(_ string) Resolver { return notFoundResolver }
3131
_, err := getCNAMEInSync(t.Context(), "web.test.com")
3232
if err != notFound {
3333
t.Errorf("Expected NotFound error, got %v", err)
@@ -36,7 +36,7 @@ func TestGetCNAMEInSync(t *testing.T) {
3636

3737
// Test when the domain is found but the DNS servers are not in sync
3838
t.Run("DNS servers not in sync", func(t *testing.T) {
39-
ResolverAt = func(nsServer string) Resolver {
39+
resolverAt = func(nsServer string) Resolver {
4040
if nsServer == "ns1.example.com" {
4141
return foundResolver
4242
} else {
@@ -51,7 +51,7 @@ func TestGetCNAMEInSync(t *testing.T) {
5151

5252
// Test when the domain is found and the DNS servers are in sync
5353
t.Run("DNS servers in sync", func(t *testing.T) {
54-
ResolverAt = func(_ string) Resolver { return foundResolver }
54+
resolverAt = func(_ string) Resolver { return foundResolver }
5555
cname, err := getCNAMEInSync(t.Context(), "web.test.com")
5656
if err != nil {
5757
t.Errorf("Expected no error, got %v", err)
@@ -65,7 +65,7 @@ func TestGetCNAMEInSync(t *testing.T) {
6565

6666
func TestGetIPInSync(t *testing.T) {
6767
t.Cleanup(func() {
68-
ResolverAt = DirectResolverAt
68+
resolverAt = DirectResolverAt
6969
})
7070

7171
notFoundResolver := MockResolver{Records: map[DNSRequest]DNSResponse{
@@ -83,7 +83,7 @@ func TestGetIPInSync(t *testing.T) {
8383

8484
// Test when the domain is not found
8585
t.Run("domain not found", func(t *testing.T) {
86-
ResolverAt = func(_ string) Resolver { return notFoundResolver }
86+
resolverAt = func(_ string) Resolver { return notFoundResolver }
8787
_, err := getIPInSync(t.Context(), "test.com")
8888
if err != notFound {
8989
t.Errorf("Expected NotFound error, got %v", err)
@@ -92,7 +92,7 @@ func TestGetIPInSync(t *testing.T) {
9292

9393
// Test when the domain is found but the DNS servers are not in sync
9494
t.Run("DNS servers not in sync", func(t *testing.T) {
95-
ResolverAt = func(nsServer string) Resolver {
95+
resolverAt = func(nsServer string) Resolver {
9696
if nsServer == "ns1.example.com" {
9797
return foundResolver
9898
} else {
@@ -107,7 +107,7 @@ func TestGetIPInSync(t *testing.T) {
107107

108108
// 2nd not in sync scenario
109109
t.Run("DNS servers not in sync with partial results", func(t *testing.T) {
110-
ResolverAt = func(nsServer string) Resolver {
110+
resolverAt = func(nsServer string) Resolver {
111111
if nsServer == "ns1.example.com" {
112112
return partialFoundResolver
113113
} else {
@@ -122,7 +122,7 @@ func TestGetIPInSync(t *testing.T) {
122122

123123
// Test when the domain is found and the DNS servers are in sync
124124
t.Run("DNS servers in sync", func(t *testing.T) {
125-
ResolverAt = func(_ string) Resolver { return foundResolver }
125+
resolverAt = func(_ string) Resolver { return foundResolver }
126126
ips, err := getIPInSync(t.Context(), "test.com")
127127
if err != nil {
128128
t.Errorf("Expected no error, got %v", err)
@@ -153,42 +153,42 @@ func TestCheckDomainDNSReady(t *testing.T) {
153153
}}
154154
resolver = hasARecordResolver
155155

156-
oldResolver, oldDebug := ResolverAt, term.DoDebug()
156+
oldResolver, oldDebug := resolverAt, term.DoDebug()
157157
t.Cleanup(func() {
158-
ResolverAt = oldResolver
158+
resolverAt = oldResolver
159159
term.SetDebug(oldDebug)
160160
})
161161
term.SetDebug(true)
162162

163163
t.Run("CNAME and A records not found", func(t *testing.T) {
164-
ResolverAt = func(_ string) Resolver { return emptyResolver }
164+
resolverAt = func(_ string) Resolver { return emptyResolver }
165165
if CheckDomainDNSReady(t.Context(), "api.test.com", []string{"some-alb.domain.com"}) != false {
166166
t.Errorf("Expected false when both CNAME and A records are missing, got true")
167167
}
168168
})
169169

170170
t.Run("CNAME setup correctly", func(t *testing.T) {
171-
ResolverAt = func(_ string) Resolver { return hasCNAMEResolver }
171+
resolverAt = func(_ string) Resolver { return hasCNAMEResolver }
172172
if CheckDomainDNSReady(t.Context(), "api.test.com", []string{"some-alb.domain.com"}) != true {
173173
t.Errorf("Expected true when CNAME is setup correctly, got false")
174174
}
175175
})
176176

177177
t.Run("CNAME setup incorrectly", func(t *testing.T) {
178-
ResolverAt = func(_ string) Resolver { return hasCNAMEResolver }
178+
resolverAt = func(_ string) Resolver { return hasCNAMEResolver }
179179
if CheckDomainDNSReady(t.Context(), "api.test.com", []string{"some-other-alb.domain.com"}) != false {
180180
t.Errorf("Expected false when CNAME is setup incorrectly, got true")
181181
}
182182
})
183183

184184
t.Run("A record setup correctly", func(t *testing.T) {
185-
ResolverAt = func(_ string) Resolver { return hasARecordResolver }
185+
resolverAt = func(_ string) Resolver { return hasARecordResolver }
186186
if CheckDomainDNSReady(t.Context(), "api.test.com", []string{"some-alb.domain.com"}) != true {
187187
t.Errorf("Expected true when A record is setup correctly, got false")
188188
}
189189
})
190190
t.Run("A record setup incorrectly", func(t *testing.T) {
191-
ResolverAt = func(_ string) Resolver { return hasWrongARecordResolver }
191+
resolverAt = func(_ string) Resolver { return hasWrongARecordResolver }
192192
if CheckDomainDNSReady(t.Context(), "api.test.com", []string{"some-alb.domain.com"}) != false {
193193
t.Errorf("Expected false when A record is setup incorrectly, got true")
194194
}

src/pkg/dns/fabric_test.go

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package dns
33
import (
44
"context"
55
"errors"
6+
"sync"
67
"testing"
78

89
defangv1 "github.com/DefangLabs/defang/src/protos/io/defang/v1"
@@ -109,7 +110,7 @@ func TestFabricResolverLookupNS(t *testing.T) {
109110
func TestUseFabricResolver(t *testing.T) {
110111
t.Cleanup(func() {
111112
fabricClient = nil
112-
ResolverAt = DirectResolverAt
113+
resolverAt = DirectResolverAt
113114
})
114115

115116
m := &mockFabricClient{ipResp: &defangv1.ResolveIPAddrResponse{IpAddrs: []string{"9.9.9.9"}}}
@@ -130,3 +131,38 @@ func TestUseFabricResolver(t *testing.T) {
130131
t.Errorf("ResolverAt did not return FabricResolver: %T %+v", r, r)
131132
}
132133
}
134+
135+
// TestResolverAtConcurrentWithUseFabricResolver exercises the synchronization
136+
// between UseFabricResolver (which swaps resolverAt) and ResolverAt callers.
137+
// Run with `go test -race` — prior to the mutex-guarded ResolverAt, concurrent
138+
// writes and reads on the package-level variable were a data race.
139+
func TestResolverAtConcurrentWithUseFabricResolver(t *testing.T) {
140+
t.Cleanup(func() {
141+
fabricClient = nil
142+
resolverAt = DirectResolverAt
143+
})
144+
145+
m := &mockFabricClient{nsResp: &defangv1.ResolveNSResponse{Hosts: []string{"ns1.example.com."}}}
146+
147+
var wg sync.WaitGroup
148+
stop := make(chan struct{})
149+
for range 4 {
150+
wg.Add(1)
151+
go func() {
152+
defer wg.Done()
153+
for {
154+
select {
155+
case <-stop:
156+
return
157+
default:
158+
_ = ResolverAt("ns.example.com")
159+
}
160+
}
161+
}()
162+
}
163+
for range 200 {
164+
UseFabricResolver(m)
165+
}
166+
close(stop)
167+
wg.Wait()
168+
}

src/pkg/dns/resolver.go

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@ type FabricResolverClient interface {
2828
ResolveNS(context.Context, *defangv1.ResolveNSRequest) (*defangv1.ResolveNSResponse, error)
2929
}
3030

31-
// fabricMu guards concurrent access to fabricClient and the ResolverAt
32-
// assignment inside UseFabricResolver.
31+
// fabricMu guards concurrent access to fabricClient and resolverAt.
3332
var fabricMu sync.RWMutex
3433

3534
// fabricClient is set by UseFabricResolver. When non-nil, RootResolver and
@@ -43,7 +42,7 @@ func UseFabricResolver(c FabricResolverClient) {
4342
fabricMu.Lock()
4443
defer fabricMu.Unlock()
4544
fabricClient = c
46-
ResolverAt = func(nsServer string) Resolver {
45+
resolverAt = func(nsServer string) Resolver {
4746
return FabricResolver{Client: c, NSServer: nsServer}
4847
}
4948
}
@@ -195,7 +194,20 @@ func DirectResolverAt(nsServer string) Resolver {
195194
return DirectResolver{NSServer: nsServer}
196195
}
197196

198-
var ResolverAt = DirectResolverAt
197+
// resolverAt is the package-private function that produces a Resolver bound to
198+
// a given nameserver. It is swapped out by UseFabricResolver. All reads must go
199+
// through ResolverAt so they're synchronized with that write.
200+
var resolverAt = DirectResolverAt
201+
202+
// ResolverAt returns a Resolver bound to nsServer. When UseFabricResolver has
203+
// wired in a fabric client, the returned Resolver issues remote RPCs;
204+
// otherwise it performs direct UDP DNS queries.
205+
func ResolverAt(nsServer string) Resolver {
206+
fabricMu.RLock()
207+
fn := resolverAt
208+
fabricMu.RUnlock()
209+
return fn(nsServer)
210+
}
199211

200212
var ErrNoSuchHost = &net.DNSError{Err: "no such host", IsNotFound: true}
201213

src/pkg/dns/resolver_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@ import (
77

88
func TestFindNSServer(t *testing.T) {
99
t.Cleanup(func() {
10-
ResolverAt = DirectResolverAt
10+
resolverAt = DirectResolverAt
1111
})
1212

1313
t.Run("NS server not exist on domain", func(t *testing.T) {
14-
ResolverAt = func(nsServer string) Resolver {
14+
resolverAt = func(nsServer string) Resolver {
1515
if strings.Contains(nsServer, "root-servers.net") {
1616
return MockResolver{Records: map[DNSRequest]DNSResponse{
1717
{Type: "NS", Domain: "a.b.c.d"}: {Records: []string{"1.tld-servers.com", "2.tld-servers.com"}, Error: nil},
@@ -42,7 +42,7 @@ func TestFindNSServer(t *testing.T) {
4242
})
4343

4444
t.Run("NS server exist on domain (delegarted apex domain)", func(t *testing.T) {
45-
ResolverAt = func(nsServer string) Resolver {
45+
resolverAt = func(nsServer string) Resolver {
4646
if strings.Contains(nsServer, "root-servers.net") {
4747
return MockResolver{Records: map[DNSRequest]DNSResponse{
4848
{Type: "NS", Domain: "a.b.c.d"}: {Records: []string{"1.tld-servers.com", "2.tld-servers.com"}, Error: nil},

0 commit comments

Comments
 (0)