Commit aa68b58

reyortiz3 and claude authored
Fix GetDefaultProvider bypassing registered config factory (#4740)
* Fix GetDefaultProvider bypassing registered config factory

  config.NewDefaultProvider() skips the RegisterProviderFactory hook and
  always reads the local XDG config file. Switch to config.NewProvider(),
  which checks the registered factory first and falls back to the default
  only when no factory is registered.

  Without this fix, enterprise builds that register a factory to supply an
  EnterpriseProvider (e.g. fetching config from a remote config server) were
  silently ignored, causing thv registry commands to use the embedded
  registry instead of the configured one.

  Add unit tests covering the factory-respected and fall-through cases,
  caching semantics, and reset behaviour.

* Fix NewRegistryRoutes constructors bypassing registered config factory

  Both NewRegistryRoutes and NewRegistryRoutesForServe called
  config.NewDefaultProvider(), which skips the RegisterProviderFactory hook
  and always reads the local XDG config file. Switch to config.NewProvider()
  so enterprise builds that register a factory (e.g. to supply a remote
  config server) are correctly honoured.

  Add unit tests covering the regression case and the no-factory fallback
  for both constructors.

* Move config.NewProvider() inside sync.Once closure in GetDefaultProvider

  Go evaluates function arguments eagerly, so the previous implementation
  called config.NewProvider() (and thus any registered ProviderFactory) on
  every invocation of GetDefaultProvider, even after sync.Once had already
  fired and the resulting provider was discarded. Move the call inside the
  Do closure so the factory is invoked at most once. GetDefaultProvider now
  owns its sync.Once block directly rather than delegating to
  GetDefaultProviderWithConfig.

* Fix data race between ResetDefaultProvider and GetDefaultProviderWithConfig

  ResetDefaultProvider assigned defaultProviderOnce = sync.Once{} (a
  non-atomic struct write) while GetDefaultProviderWithConfig called
  defaultProviderOnce.Do(), which does an atomic load of the internal done
  field. Mixed-width concurrent access = data race.

  Replace the three separate package-level variables (sync.Once, Provider,
  error) with a single providerState struct stored behind an atomic.Pointer.
  ResetDefaultProvider now atomically swaps in a fresh struct instead of
  writing to one that may be in use. Goroutines that loaded a pointer before
  a reset keep a stable reference and complete safely; goroutines that load
  after the swap initialise against the new state. The mutex is no longer
  needed and is removed.

* Add nolint:paralleltest directives to factory tests

  These tests mutate global singletons (config.registeredFactory and
  currentProviderState) so they cannot run in parallel. Suppress the
  paralleltest linter warning with an explanatory comment on each test.

* Fix gci formatting: move nolint directives after doc comments

  gci requires nolint directives to appear immediately before the func
  declaration, after the doc comment block.

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
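The eager-evaluation point in the commit message can be reproduced in isolation. The following is a minimal sketch, not the toolhive code: the counting `factory` closure stands in for a registered ProviderFactory, and `eagerCalls` and `lazyCalls` are hypothetical helper names chosen for this illustration.

```go
package main

import (
	"fmt"
	"sync"
)

// eagerCalls mimics the old shape: the factory result is passed as an
// argument, so Go evaluates the factory before every call, even though
// sync.Once only uses the first result. It returns how often the
// (hypothetical) factory ran.
func eagerCalls(n int) int {
	calls := 0
	factory := func() string { calls++; return "provider" }

	var once sync.Once
	var cached string
	get := func(p string) string {
		once.Do(func() { cached = p })
		return cached
	}

	for i := 0; i < n; i++ {
		get(factory()) // factory() runs here on every iteration
	}
	return calls
}

// lazyCalls mimics the fixed shape: the factory is only referenced inside
// the Do closure, so it runs at most once.
func lazyCalls(n int) int {
	calls := 0
	factory := func() string { calls++; return "provider" }

	var once sync.Once
	var cached string
	get := func() string {
		once.Do(func() { cached = factory() })
		return cached
	}

	for i := 0; i < n; i++ {
		get()
	}
	return calls
}

func main() {
	fmt.Println("eager factory calls:", eagerCalls(3))
	fmt.Println("lazy factory calls:", lazyCalls(3))
}
```

In the eager variant the factory runs once per call, which matters when the factory is expensive (for example, a remote fetch); in the lazy variant it runs once in total.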
1 parent 443d2fa commit aa68b58

4 files changed

Lines changed: 468 additions & 33 deletions

pkg/api/v1/registry.go

Lines changed: 3 additions & 3 deletions
@@ -14,7 +14,7 @@ import (
 
 	"github.com/go-chi/chi/v5"
 
-	"github.com/stacklok/toolhive-core/registry/types"
+	registry "github.com/stacklok/toolhive-core/registry/types"
 	"github.com/stacklok/toolhive/pkg/config"
 	regpkg "github.com/stacklok/toolhive/pkg/registry"
 	"github.com/stacklok/toolhive/pkg/registry/auth"
@@ -286,7 +286,7 @@ type RegistryRoutes struct {
 // NewRegistryRoutes creates a new RegistryRoutes with the default config provider
 func NewRegistryRoutes() *RegistryRoutes {
 	return &RegistryRoutes{
-		configProvider: config.NewDefaultProvider(),
+		configProvider: config.NewProvider(),
 		configService:  regpkg.NewConfigurator(),
 	}
 }
@@ -304,7 +304,7 @@ func NewRegistryRoutesWithProvider(provider config.Provider) *RegistryRoutes {
 // In serve mode, the registry provider uses non-interactive auth (no browser OAuth).
 func NewRegistryRoutesForServe() *RegistryRoutes {
 	return &RegistryRoutes{
-		configProvider: config.NewDefaultProvider(),
+		configProvider: config.NewProvider(),
 		configService:  regpkg.NewConfigurator(),
 		serveMode:      true,
 	}
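The difference between the two constructors comes down to whether the registered factory is consulted. Below is a rough sketch of that lookup order as described in the commit message. The `Provider` interface, both provider types, and the function bodies are assumptions for illustration only; the real `pkg/config` implementation is not shown in this diff and may differ (it may, for instance, guard the factory variable with a lock).

```go
package main

import "fmt"

// Provider is a stand-in for config.Provider, reduced to a single method
// for this sketch.
type Provider interface {
	Source() string
}

// localProvider and remoteProvider are hypothetical implementations.
type localProvider struct{}

func (localProvider) Source() string { return "local config file" }

type remoteProvider struct{}

func (remoteProvider) Source() string { return "remote config server" }

// registeredFactory is assumed package-level state; nil means no factory
// has been registered. Not goroutine-safe in this sketch.
var registeredFactory func() Provider

// RegisterProviderFactory installs (or, with nil, clears) the factory.
func RegisterProviderFactory(f func() Provider) { registeredFactory = f }

// NewDefaultProvider always returns the local provider and never looks at
// the factory. This is the behaviour the commit moves away from.
func NewDefaultProvider() Provider { return localProvider{} }

// NewProvider checks the registered factory first and falls back to the
// default only when no factory is registered.
func NewProvider() Provider {
	if registeredFactory != nil {
		return registeredFactory()
	}
	return NewDefaultProvider()
}

func main() {
	fmt.Println(NewProvider().Source())

	RegisterProviderFactory(func() Provider { return remoteProvider{} })
	fmt.Println(NewProvider().Source())
	fmt.Println(NewDefaultProvider().Source())
}
```

With a factory registered, only `NewProvider` reflects it; `NewDefaultProvider` keeps returning the local provider, which is the silent bypass the commit fixes.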
Lines changed: 201 additions & 0 deletions
@@ -0,0 +1,201 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package v1
+
+import (
+	"context"
+	"encoding/json"
+	"net/http"
+	"net/http/httptest"
+	"os"
+	"path/filepath"
+	"testing"
+
+	"github.com/go-chi/chi/v5"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"gopkg.in/yaml.v3"
+
+	"github.com/stacklok/toolhive/pkg/config"
+	"github.com/stacklok/toolhive/pkg/registry"
+)
+
+// writeFactorySentinelRegistry creates a legacy-format registry JSON file with a
+// single server named sentinelName and a YAML config pointing to it.
+// Returns the config file path.
+func writeFactorySentinelRegistry(t *testing.T, sentinelName string) string {
+	t.Helper()
+
+	dir := t.TempDir()
+
+	// Write legacy registry JSON with the sentinel server.
+	type serverEntry struct {
+		Image       string `json:"image"`
+		Description string `json:"description"`
+	}
+	type registryFile struct {
+		Version     string                 `json:"version"`
+		LastUpdated string                 `json:"last_updated"`
+		Servers     map[string]serverEntry `json:"servers"`
+	}
+
+	regData, err := json.Marshal(registryFile{
+		Version:     "1.0.0",
+		LastUpdated: "2025-01-01T00:00:00Z",
+		Servers: map[string]serverEntry{
+			sentinelName: {
+				Image:       "factory/server:latest",
+				Description: "Factory sentinel server",
+			},
+		},
+	})
+	require.NoError(t, err)
+
+	registryPath := filepath.Join(dir, "registry.json")
+	require.NoError(t, os.WriteFile(registryPath, regData, 0600))
+
+	// Write YAML config pointing to the registry JSON.
+	type configFile struct {
+		LocalRegistryPath string `yaml:"local_registry_path"`
+	}
+
+	cfgData, err := yaml.Marshal(configFile{LocalRegistryPath: registryPath})
+	require.NoError(t, err)
+
+	configPath := filepath.Join(dir, "config.yaml")
+	require.NoError(t, os.WriteFile(configPath, cfgData, 0600))
+
+	return configPath
+}
+
+// makeListServersRequest builds an httptest request for GET /{name}/servers
+// with the chi URL param "name" set to registryName.
+func makeListServersRequest(registryName string) *http.Request {
+	req := httptest.NewRequest(http.MethodGet, "/"+registryName+"/servers", nil)
+	rctx := chi.NewRouteContext()
+	rctx.URLParams.Add("name", registryName)
+	return req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
+}
+
+// TestNewRegistryRoutes_RespectsRegisteredFactory is the critical regression test
+// for the bug fix. Before the fix, NewRegistryRoutes called config.NewDefaultProvider(),
+// which bypassed any registered ProviderFactory. The fix changed it to call
+// config.NewProvider(), which checks the factory first.
+//
+// The test registers a factory that returns a PathProvider pointing at a local
+// registry JSON containing a sentinel server name. If NewRegistryRoutes correctly
+// forwards the factory-backed provider to getCurrentProvider, the listServers
+// handler will return that sentinel server in its response.
+//
+//nolint:paralleltest // Mutates global state: config.registeredFactory and regpkg.defaultProviderOnce
+func TestNewRegistryRoutes_RespectsRegisteredFactory(t *testing.T) {
+	const sentinelName = "factory-sentinel-server"
+
+	configPath := writeFactorySentinelRegistry(t, sentinelName)
+
+	config.RegisterProviderFactory(func() config.Provider {
+		return config.NewPathProvider(configPath)
+	})
+	t.Cleanup(func() {
+		config.RegisterProviderFactory(nil)
+		registry.ResetDefaultProvider()
+	})
+
+	routes := NewRegistryRoutes()
+
+	// Clear provider cache so getCurrentProvider re-initialises using our factory.
+	registry.ResetDefaultProvider()
+
+	w := httptest.NewRecorder()
+	routes.listServers(w, makeListServersRequest("default"))
+
+	assert.Equal(t, http.StatusOK, w.Code,
+		"listServers should return 200 when factory-backed provider is used")
+
+	var body listServersResponse
+	require.NoError(t, json.NewDecoder(w.Body).Decode(&body),
+		"response body should be valid JSON")
+
+	names := make([]string, 0, len(body.Servers))
+	for _, s := range body.Servers {
+		names = append(names, s.GetName())
+	}
+	assert.Contains(t, names, sentinelName,
+		"sentinel server must be present; this would fail on the old code that called config.NewDefaultProvider()")
+}
+
+// TestNewRegistryRoutesForServe_RespectsRegisteredFactory verifies that the
+// serve-mode constructor also honours the registered ProviderFactory. This
+// mirrors TestNewRegistryRoutes_RespectsRegisteredFactory but exercises
+// NewRegistryRoutesForServe and the serveMode code path.
+//
+//nolint:paralleltest // Mutates global state: config.registeredFactory and regpkg.defaultProviderOnce
+func TestNewRegistryRoutesForServe_RespectsRegisteredFactory(t *testing.T) {
+	const sentinelName = "factory-sentinel-server"
+
+	configPath := writeFactorySentinelRegistry(t, sentinelName)
+
+	config.RegisterProviderFactory(func() config.Provider {
+		return config.NewPathProvider(configPath)
+	})
+	t.Cleanup(func() {
+		config.RegisterProviderFactory(nil)
+		registry.ResetDefaultProvider()
+	})
+
+	routes := NewRegistryRoutesForServe()
+
+	// Clear provider cache so getCurrentProvider re-initialises using our factory.
+	registry.ResetDefaultProvider()
+
+	w := httptest.NewRecorder()
+	routes.listServers(w, makeListServersRequest("default"))
+
+	assert.Equal(t, http.StatusOK, w.Code,
+		"listServers should return 200 when factory-backed provider is used in serve mode")
+
+	var body listServersResponse
+	require.NoError(t, json.NewDecoder(w.Body).Decode(&body),
+		"response body should be valid JSON")
+
+	names := make([]string, 0, len(body.Servers))
+	for _, s := range body.Servers {
+		names = append(names, s.GetName())
+	}
+	assert.Contains(t, names, sentinelName,
+		"sentinel server must be present; this would fail on the old code that called config.NewDefaultProvider()")
+}
+
+// TestNewRegistryRoutes_NoFactory_ReturnsValidRoutes verifies that NewRegistryRoutes
+// returns a fully-initialised struct when no ProviderFactory is registered.
+//
+//nolint:paralleltest // Mutates global state: config.registeredFactory
+func TestNewRegistryRoutes_NoFactory_ReturnsValidRoutes(t *testing.T) {
+	config.RegisterProviderFactory(nil)
+	t.Cleanup(func() { config.RegisterProviderFactory(nil) })
+
+	routes := NewRegistryRoutes()
+
+	require.NotNil(t, routes, "NewRegistryRoutes must return a non-nil value")
+	assert.NotNil(t, routes.configProvider, "configProvider must be initialised")
+	assert.NotNil(t, routes.configService, "configService must be initialised")
+	assert.False(t, routes.serveMode, "serveMode must be false for NewRegistryRoutes")
+}
+
+// TestNewRegistryRoutesForServe_NoFactory_ReturnsValidRoutes verifies that
+// NewRegistryRoutesForServe returns a fully-initialised struct with serveMode
+// set to true when no ProviderFactory is registered.
+//
+//nolint:paralleltest // Mutates global state: config.registeredFactory
+func TestNewRegistryRoutesForServe_NoFactory_ReturnsValidRoutes(t *testing.T) {
+	config.RegisterProviderFactory(nil)
+	t.Cleanup(func() { config.RegisterProviderFactory(nil) })
+
+	routes := NewRegistryRoutesForServe()
+
+	require.NotNil(t, routes, "NewRegistryRoutesForServe must return a non-nil value")
+	assert.NotNil(t, routes.configProvider, "configProvider must be initialised")
+	assert.NotNil(t, routes.configService, "configService must be initialised")
+	assert.True(t, routes.serveMode, "serveMode must be true for NewRegistryRoutesForServe")
+}
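The regression tests above rely on a sentinel: a server name that can only appear in the response if the handler read from the factory-backed source. The core of that check can be sketched with the standard library alone. This is an illustrative reduction, not the toolhive handler: `listResponse`, `listServers`, and `sentinelVisible` are hypothetical names, and the data source is a plain function rather than a registry provider.

```go
package main

import (
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
)

// listResponse is a hypothetical response shape for this sketch.
type listResponse struct {
	Servers []string `json:"servers"`
}

// listServers returns a handler that reports whatever names the given
// source yields. The source plays the role of the configured registry.
func listServers(source func() []string) http.HandlerFunc {
	return func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(listResponse{Servers: source()})
	}
}

// sentinelVisible calls the handler directly against a recorder, decodes
// the JSON body, and reports whether the sentinel name is present.
func sentinelVisible(h http.HandlerFunc, sentinel string) (bool, error) {
	w := httptest.NewRecorder()
	h(w, httptest.NewRequest(http.MethodGet, "/default/servers", nil))
	if w.Code != http.StatusOK {
		return false, fmt.Errorf("unexpected status %d", w.Code)
	}

	var body listResponse
	if err := json.NewDecoder(w.Body).Decode(&body); err != nil {
		return false, err
	}
	for _, name := range body.Servers {
		if name == sentinel {
			return true, nil
		}
	}
	return false, nil
}

func main() {
	const sentinel = "factory-sentinel-server"

	fromFactory := listServers(func() []string { return []string{sentinel} })
	fromDefault := listServers(func() []string { return []string{"some-embedded-server"} })

	ok, err := sentinelVisible(fromFactory, sentinel)
	fmt.Println(ok, err)
	ok, err = sentinelVisible(fromDefault, sentinel)
	fmt.Println(ok, err)
}
```

Because the sentinel name is unlikely to exist in any default data set, its presence in the response is evidence that the intended source was used, which is stronger than asserting on the status code alone.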

pkg/registry/factory.go

Lines changed: 42 additions & 30 deletions
@@ -10,22 +10,30 @@ import (
 	"fmt"
 	"log/slog"
 	"sync"
+	"sync/atomic"
 
 	"github.com/stacklok/toolhive/pkg/config"
 	"github.com/stacklok/toolhive/pkg/registry/auth"
 	"github.com/stacklok/toolhive/pkg/secrets"
 )
 
-var (
-	defaultProvider     Provider
-	defaultProviderOnce sync.Once
-	defaultProviderErr  error
-	// defaultProviderMu protects the ResetDefaultProvider operation
-	// to prevent race conditions when resetting the sync.Once.
-	// The mutex is NOT needed for GetDefaultProviderWithConfig since
-	// sync.Once already provides thread-safety for initialization.
-	defaultProviderMu sync.Mutex
-)
+// providerState groups the sync.Once with the values it initialises.
+// Storing all three together behind an atomic pointer means ResetDefaultProvider
+// can swap in a fresh struct without ever writing to a struct that another
+// goroutine may be reading, eliminating the data race between Reset and Do.
+type providerState struct {
+	once     sync.Once
+	provider Provider
+	err      error
+}
+
+// currentProviderState is the live singleton state. Replaced atomically by
+// ResetDefaultProvider; never mutated after creation except inside once.Do.
+var currentProviderState atomic.Pointer[providerState]
+
+func init() {
+	currentProviderState.Store(&providerState{})
+}
 
 // ProviderOption configures optional behavior for NewRegistryProvider.
 type ProviderOption func(*providerOptions)
@@ -79,42 +87,46 @@ func NewRegistryProvider(cfg *config.Config, opts ...ProviderOption) (Provider,
 	return NewLocalRegistryProvider(), nil
 }
 
-// GetDefaultProvider returns the default registry provider instance
-// This maintains backward compatibility with the existing singleton pattern
+// GetDefaultProvider returns the default registry provider instance.
+// config.NewProvider() is called inside the sync.Once closure so that any
+// registered ProviderFactory is invoked at most once, not on every call.
 func GetDefaultProvider() (Provider, error) {
-	return GetDefaultProviderWithConfig(config.NewDefaultProvider())
+	s := currentProviderState.Load()
+	s.once.Do(func() {
+		cfg, err := config.NewProvider().LoadOrCreateConfig()
+		if err != nil {
+			s.err = err
+			return
+		}
+		s.provider, s.err = NewRegistryProvider(cfg)
+	})
+	return s.provider, s.err
 }
 
 // GetDefaultProviderWithConfig returns a registry provider using the given config provider.
 // This allows tests to inject their own config provider.
 // The interactive flag controls whether browser-based OAuth flows are allowed.
 // Pass true for CLI contexts, false for headless/serve mode.
 func GetDefaultProviderWithConfig(configProvider config.Provider, opts ...ProviderOption) (Provider, error) {
-	defaultProviderOnce.Do(func() {
+	s := currentProviderState.Load()
+	s.once.Do(func() {
 		cfg, err := configProvider.LoadOrCreateConfig()
 		if err != nil {
-			defaultProviderErr = err
+			s.err = err
 			return
 		}
-		defaultProvider, defaultProviderErr = NewRegistryProvider(cfg, opts...)
+		s.provider, s.err = NewRegistryProvider(cfg, opts...)
 	})
-
-	return defaultProvider, defaultProviderErr
+	return s.provider, s.err
 }
 
-// ResetDefaultProvider clears the cached default provider instance
-// This allows the provider to be recreated with updated configuration.
-// This function is thread-safe and can be called concurrently.
-// The mutex is required here because we're modifying the sync.Once itself,
-// which is not a thread-safe operation.
+// ResetDefaultProvider clears the cached default provider instance so the
+// next call to GetDefaultProvider or GetDefaultProviderWithConfig creates a
+// fresh one. The atomic swap is safe to call concurrently: goroutines that
+// already hold a reference to the old state finish against that state cleanly,
+// while goroutines that load after the swap initialise against the new state.
 func ResetDefaultProvider() {
-	defaultProviderMu.Lock()
-	defer defaultProviderMu.Unlock()
-
-	// Reset the sync.Once to allow re-initialization
-	defaultProviderOnce = sync.Once{}
-	defaultProvider = nil
-	defaultProviderErr = nil
+	currentProviderState.Store(&providerState{})
 }
 
 // resolveTokenSource creates a TokenSource from the config if registry auth is configured.