Skip to content

Commit 1dbacad

Browse files
edw-defangedwardrf
authored andcommitted
Fix fabric dns client race condition
1 parent a02e5b3 commit 1dbacad

1 file changed

Lines changed: 19 additions & 6 deletions

File tree

src/pkg/dns/resolver.go

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"net"
88
"slices"
99
"sort"
10+
"sync"
1011

1112
"github.com/DefangLabs/defang/src/pkg"
1213
defangv1 "github.com/DefangLabs/defang/src/protos/io/defang/v1"
@@ -27,6 +28,10 @@ type FabricResolverClient interface {
2728
ResolveNS(context.Context, *defangv1.ResolveNSRequest) (*defangv1.ResolveNSResponse, error)
2829
}
2930

31+
// fabricMu guards concurrent access to fabricClient and the ResolverAt
32+
// assignment inside UseFabricResolver.
33+
var fabricMu sync.RWMutex
34+
3035
// fabricClient is set by UseFabricResolver. When non-nil, RootResolver and
3136
// ResolverAt route DNS lookups through the fabric gRPC API.
3237
var fabricClient FabricResolverClient
@@ -35,12 +40,20 @@ var fabricClient FabricResolverClient
3540
// called, RootResolver{} and ResolverAt(nsServer) both issue remote RPCs
3641
// instead of performing direct UDP DNS queries.
3742
func UseFabricResolver(c FabricResolverClient) {
43+
fabricMu.Lock()
44+
defer fabricMu.Unlock()
3845
fabricClient = c
3946
ResolverAt = func(nsServer string) Resolver {
4047
return FabricResolver{Client: c, NSServer: nsServer}
4148
}
4249
}
4350

51+
func getFabricClient() FabricResolverClient {
52+
fabricMu.RLock()
53+
defer fabricMu.RUnlock()
54+
return fabricClient
55+
}
56+
4457
// FabricResolver performs DNS lookups via the fabric gRPC API. An empty
4558
// NSServer lets the server perform recursive resolution from the root.
4659
type FabricResolver struct {
@@ -117,8 +130,8 @@ var rootServers = []*net.NS{
117130
}
118131

119132
func (r RootResolver) LookupIPAddr(ctx context.Context, domain string) ([]net.IPAddr, error) {
120-
if fabricClient != nil {
121-
return FabricResolver{Client: fabricClient}.LookupIPAddr(ctx, domain)
133+
if c := getFabricClient(); c != nil {
134+
return FabricResolver{Client: c}.LookupIPAddr(ctx, domain)
122135
}
123136
for range 10 {
124137
ips, err := r.getResolver(ctx, domain).LookupIPAddr(ctx, domain)
@@ -136,15 +149,15 @@ func (r RootResolver) LookupIPAddr(ctx context.Context, domain string) ([]net.IP
136149
}
137150

138151
func (r RootResolver) LookupCNAME(ctx context.Context, domain string) (string, error) {
139-
if fabricClient != nil {
140-
return FabricResolver{Client: fabricClient}.LookupCNAME(ctx, domain)
152+
if c := getFabricClient(); c != nil {
153+
return FabricResolver{Client: c}.LookupCNAME(ctx, domain)
141154
}
142155
return r.getResolver(ctx, domain).LookupCNAME(ctx, domain)
143156
}
144157

145158
func (r RootResolver) LookupNS(ctx context.Context, domain string) ([]*net.NS, error) {
146-
if fabricClient != nil {
147-
return FabricResolver{Client: fabricClient}.LookupNS(ctx, domain)
159+
if c := getFabricClient(); c != nil {
160+
return FabricResolver{Client: c}.LookupNS(ctx, domain)
148161
}
149162
return r.getResolver(ctx, domain).LookupNS(ctx, domain)
150163
}

0 commit comments

Comments
 (0)